mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
Disable trace_on_chain_group auto-tracing (#12807)
Previously we treated trace_on_chain_group as a command to always start tracing. This is unintuitive (makes the function do 2 things), and makes it harder to toggle tracing
This commit is contained in:
parent
0da75b9ebd
commit
18005c6384
@ -124,9 +124,11 @@ def tracing_enabled(
|
||||
"""
|
||||
cb = LangChainTracerV1()
|
||||
session = cast(TracerSessionV1, cb.load_session(session_name))
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
tracing_callback_var.set(None)
|
||||
try:
|
||||
tracing_callback_var.set(cb)
|
||||
yield session
|
||||
finally:
|
||||
tracing_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -191,9 +193,11 @@ def tracing_v2_enabled(
|
||||
tags=tags,
|
||||
client=client,
|
||||
)
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield cb
|
||||
tracing_v2_callback_var.set(None)
|
||||
try:
|
||||
tracing_v2_callback_var.set(cb)
|
||||
yield cb
|
||||
finally:
|
||||
tracing_v2_callback_var.set(None)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@ -214,6 +218,33 @@ def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None,
|
||||
run_collector_var.set(None)
|
||||
|
||||
|
||||
def _get_trace_callbacks(
|
||||
project_name: Optional[str] = None,
|
||||
example_id: Optional[Union[str, UUID]] = None,
|
||||
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
|
||||
) -> Callbacks:
|
||||
if _tracing_v2_is_enabled():
|
||||
project_name_ = project_name or _get_tracer_project()
|
||||
tracer = tracing_v2_callback_var.get() or LangChainTracer(
|
||||
project_name=project_name_,
|
||||
example_id=example_id,
|
||||
)
|
||||
if callback_manager is None:
|
||||
cb = cast(Callbacks, [tracer])
|
||||
else:
|
||||
if not any(
|
||||
isinstance(handler, LangChainTracer)
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(tracer, True)
|
||||
# If it already has a LangChainTracer, we don't need to add another one.
|
||||
# this would likely mess up the trace hierarchy.
|
||||
cb = callback_manager
|
||||
else:
|
||||
cb = None
|
||||
return cb
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_as_chain_group(
|
||||
group_name: str,
|
||||
@ -241,6 +272,8 @@ def trace_as_chain_group(
|
||||
tags (List[str], optional): The inheritable tags to apply to all runs.
|
||||
Defaults to None.
|
||||
|
||||
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
|
||||
|
||||
Returns:
|
||||
CallbackManagerForChainGroup: The callback manager for the chain group.
|
||||
|
||||
@ -253,16 +286,8 @@ def trace_as_chain_group(
|
||||
res = llm.predict(llm_input, callbacks=manager)
|
||||
manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
LangChainTracer(
|
||||
project_name=project_name,
|
||||
example_id=example_id,
|
||||
)
|
||||
]
|
||||
if callback_manager is None
|
||||
else callback_manager,
|
||||
cb = _get_trace_callbacks(
|
||||
project_name, example_id, callback_manager=callback_manager
|
||||
)
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=cb,
|
||||
@ -321,6 +346,8 @@ async def atrace_as_chain_group(
|
||||
Returns:
|
||||
AsyncCallbackManager: The async callback manager for the chain group.
|
||||
|
||||
Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
@ -330,16 +357,8 @@ async def atrace_as_chain_group(
|
||||
res = await llm.apredict(llm_input, callbacks=manager)
|
||||
await manager.on_chain_end({"output": res})
|
||||
""" # noqa: E501
|
||||
cb = cast(
|
||||
Callbacks,
|
||||
[
|
||||
LangChainTracer(
|
||||
project_name=project_name,
|
||||
example_id=example_id,
|
||||
)
|
||||
]
|
||||
if callback_manager is None
|
||||
else callback_manager,
|
||||
cb = _get_trace_callbacks(
|
||||
project_name, example_id, callback_manager=callback_manager
|
||||
)
|
||||
cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags)
|
||||
|
||||
@ -1895,6 +1914,32 @@ def env_var_is_set(env_var: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def _tracing_v2_is_enabled() -> bool:
|
||||
return (
|
||||
env_var_is_set("LANGCHAIN_TRACING_V2")
|
||||
or tracing_v2_callback_var.get() is not None
|
||||
or get_run_tree_context() is not None
|
||||
)
|
||||
|
||||
|
||||
def _get_tracer_project() -> str:
|
||||
run_tree = get_run_tree_context()
|
||||
return getattr(
|
||||
run_tree,
|
||||
"session_name",
|
||||
getattr(
|
||||
# Note, if people are trying to nest @traceable functions and the
|
||||
# tracing_v2_enabled context manager, this will likely mess up the
|
||||
# tree structure.
|
||||
tracing_v2_callback_var.get(),
|
||||
"project",
|
||||
os.environ.get(
|
||||
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _configure(
|
||||
callback_manager_cls: Type[T],
|
||||
inheritable_callbacks: Callbacks = None,
|
||||
@ -1973,18 +2018,8 @@ def _configure(
|
||||
)
|
||||
|
||||
tracer_v2 = tracing_v2_callback_var.get()
|
||||
tracing_v2_enabled_ = (
|
||||
env_var_is_set("LANGCHAIN_TRACING_V2")
|
||||
or tracer_v2 is not None
|
||||
or run_tree is not None
|
||||
)
|
||||
tracer_project = getattr(
|
||||
run_tree,
|
||||
"session_name",
|
||||
os.environ.get(
|
||||
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
|
||||
),
|
||||
)
|
||||
tracing_v2_enabled_ = _tracing_v2_is_enabled()
|
||||
tracer_project = _get_tracer_project()
|
||||
run_collector_ = run_collector_var.get()
|
||||
debug = _get_debug()
|
||||
if (
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Test CallbackManager."""
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -9,9 +10,10 @@ from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
get_openai_callback,
|
||||
trace_as_chain_group,
|
||||
tracing_v2_enabled,
|
||||
)
|
||||
from langchain.callbacks.stdout import StdOutCallbackHandler
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
||||
from langchain.llms.openai import BaseOpenAI
|
||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import (
|
||||
@ -303,70 +305,104 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def test_callback_manager_configure_context_vars(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test callback manager configuration."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true")
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||
with patch.object(LangChainTracer, "_update_run_single"):
|
||||
with patch.object(LangChainTracer, "_persist_run_single"):
|
||||
with trace_as_chain_group("test") as group_manager:
|
||||
assert len(group_manager.handlers) == 1
|
||||
tracer = group_manager.handlers[0]
|
||||
assert isinstance(tracer, LangChainTracer)
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
wait_for_all_tracers()
|
||||
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
|
||||
|
||||
|
||||
def test_trace_as_chain_group_within_tracing_v2_context_manager(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test callback manager configuration."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||
with tracing_v2_enabled():
|
||||
with patch.object(LangChainTracer, "_update_run_single"):
|
||||
with patch.object(LangChainTracer, "_persist_run_single"):
|
||||
with trace_as_chain_group("test") as group_manager:
|
||||
assert len(group_manager.handlers) == 1
|
||||
tracer = group_manager.handlers[0]
|
||||
assert isinstance(tracer, LangChainTracer)
|
||||
wait_for_all_tracers()
|
||||
assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore
|
||||
|
||||
with trace_as_chain_group("test") as group_manager:
|
||||
assert len(group_manager.handlers) == 1
|
||||
tracer = group_manager.handlers[0]
|
||||
assert isinstance(tracer, LangChainTracer)
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
|
||||
with get_openai_callback() as cb:
|
||||
# This is a new empty callback handler
|
||||
assert cb.successful_requests == 0
|
||||
assert cb.total_tokens == 0
|
||||
|
||||
# configure adds this openai cb but doesn't modify the group manager
|
||||
mngr = CallbackManager.configure(group_manager)
|
||||
assert mngr.handlers == [tracer, cb]
|
||||
assert group_manager.handlers == [tracer]
|
||||
|
||||
response = LLMResult(
|
||||
generations=[],
|
||||
llm_output={
|
||||
"token_usage": {
|
||||
"prompt_tokens": 2,
|
||||
"completion_tokens": 1,
|
||||
"total_tokens": 3,
|
||||
},
|
||||
"model_name": BaseOpenAI.__fields__["model_name"].default,
|
||||
},
|
||||
)
|
||||
mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response)
|
||||
|
||||
# The callback handler has been updated
|
||||
assert cb.successful_requests == 1
|
||||
assert cb.total_tokens == 3
|
||||
assert cb.prompt_tokens == 2
|
||||
assert cb.completion_tokens == 1
|
||||
assert cb.total_cost > 0
|
||||
def test_trace_as_chain_group_tracing_disabled(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Test callback manager configuration."""
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false")
|
||||
monkeypatch.setenv("LANGCHAIN_TRACING", "false")
|
||||
with patch.object(LangChainTracer, "_update_run_single"):
|
||||
with patch.object(LangChainTracer, "_persist_run_single"):
|
||||
with trace_as_chain_group("test") as group_manager:
|
||||
assert len(group_manager.handlers) == 0
|
||||
assert LangChainTracer._persist_run_single.call_count == 0 # type: ignore
|
||||
|
Loading…
Reference in New Issue
Block a user