diff --git a/libs/langchain/langchain/callbacks/manager.py b/libs/langchain/langchain/callbacks/manager.py index da8b73d93be..8c19ab589e6 100644 --- a/libs/langchain/langchain/callbacks/manager.py +++ b/libs/langchain/langchain/callbacks/manager.py @@ -124,9 +124,11 @@ def tracing_enabled( """ cb = LangChainTracerV1() session = cast(TracerSessionV1, cb.load_session(session_name)) - tracing_callback_var.set(cb) - yield session - tracing_callback_var.set(None) + try: + tracing_callback_var.set(cb) + yield session + finally: + tracing_callback_var.set(None) @contextmanager @@ -191,9 +193,11 @@ def tracing_v2_enabled( tags=tags, client=client, ) - tracing_v2_callback_var.set(cb) - yield cb - tracing_v2_callback_var.set(None) + try: + tracing_v2_callback_var.set(cb) + yield cb + finally: + tracing_v2_callback_var.set(None) @contextmanager @@ -214,6 +218,33 @@ def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, run_collector_var.set(None) +def _get_trace_callbacks( + project_name: Optional[str] = None, + example_id: Optional[Union[str, UUID]] = None, + callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None, +) -> Callbacks: + if _tracing_v2_is_enabled(): + project_name_ = project_name or _get_tracer_project() + tracer = tracing_v2_callback_var.get() or LangChainTracer( + project_name=project_name_, + example_id=example_id, + ) + if callback_manager is None: + cb = cast(Callbacks, [tracer]) + else: + if not any( + isinstance(handler, LangChainTracer) + for handler in callback_manager.handlers + ): + callback_manager.add_handler(tracer, True) + # If it already has a LangChainTracer, we don't need to add another one. + # this would likely mess up the trace hierarchy. + cb = callback_manager + else: + cb = None + return cb + + @contextmanager def trace_as_chain_group( group_name: str, @@ -241,6 +272,8 @@ def trace_as_chain_group( tags (List[str], optional): The inheritable tags to apply to all runs. Defaults to None. + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + Returns: CallbackManagerForChainGroup: The callback manager for the chain group. @@ -253,16 +286,8 @@ def trace_as_chain_group( res = llm.predict(llm_input, callbacks=manager) manager.on_chain_end({"output": res}) """ # noqa: E501 - cb = cast( - Callbacks, - [ - LangChainTracer( - project_name=project_name, - example_id=example_id, - ) - ] - if callback_manager is None - else callback_manager, + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager ) cm = CallbackManager.configure( inheritable_callbacks=cb, @@ -321,6 +346,8 @@ async def atrace_as_chain_group( Returns: AsyncCallbackManager: The async callback manager for the chain group. + Note: must have LANGCHAIN_TRACING_V2 env var set to true to see the trace in LangSmith. + Example: .. code-block:: python @@ -330,16 +357,8 @@ async def atrace_as_chain_group( res = await llm.apredict(llm_input, callbacks=manager) await manager.on_chain_end({"output": res}) """ # noqa: E501 - cb = cast( - Callbacks, - [ - LangChainTracer( - project_name=project_name, - example_id=example_id, - ) - ] - if callback_manager is None - else callback_manager, + cb = _get_trace_callbacks( + project_name, example_id, callback_manager=callback_manager ) cm = AsyncCallbackManager.configure(inheritable_callbacks=cb, inheritable_tags=tags) @@ -1895,6 +1914,32 @@ def env_var_is_set(env_var: str) -> bool: ) +def _tracing_v2_is_enabled() -> bool: + return ( + env_var_is_set("LANGCHAIN_TRACING_V2") + or tracing_v2_callback_var.get() is not None + or get_run_tree_context() is not None + ) + + +def _get_tracer_project() -> str: + run_tree = get_run_tree_context() + return getattr( + run_tree, + "session_name", + getattr( + # Note, if people are trying to nest @traceable functions and the + # tracing_v2_enabled context manager, this will likely mess up the + # tree structure. + tracing_v2_callback_var.get(), + "project", + os.environ.get( + "LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default") + ), + ), + ) + + def _configure( callback_manager_cls: Type[T], inheritable_callbacks: Callbacks = None, @@ -1973,18 +2018,8 @@ def _configure( ) tracer_v2 = tracing_v2_callback_var.get() - tracing_v2_enabled_ = ( - env_var_is_set("LANGCHAIN_TRACING_V2") - or tracer_v2 is not None - or run_tree is not None - ) - tracer_project = getattr( - run_tree, - "session_name", - os.environ.get( - "LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default") - ), - ) + tracing_v2_enabled_ = _tracing_v2_is_enabled() + tracer_project = _get_tracer_project() run_collector_ = run_collector_var.get() debug = _get_debug() if ( diff --git a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py index fdbea1e9122..32670f3c72d 100644 --- a/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py +++ b/libs/langchain/tests/unit_tests/callbacks/test_callback_manager.py @@ -1,5 +1,6 @@ """Test CallbackManager.""" from typing import List, Tuple +from unittest.mock import patch import pytest @@ -9,9 +10,10 @@ from langchain.callbacks.manager import ( CallbackManager, get_openai_callback, trace_as_chain_group, + tracing_v2_enabled, ) from langchain.callbacks.stdout import StdOutCallbackHandler -from langchain.callbacks.tracers.langchain import LangChainTracer +from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers from langchain.llms.openai import BaseOpenAI from langchain.schema import AgentAction, AgentFinish, LLMResult from tests.unit_tests.callbacks.fake_callback_handler import ( @@ -303,70 +305,104 @@ def test_callback_manager_configure(monkeypatch: pytest.MonkeyPatch) -> None: def test_callback_manager_configure_context_vars( monkeypatch: pytest.MonkeyPatch, ) -> None: + """Test callback manager configuration.""" + monkeypatch.setenv("LANGCHAIN_TRACING_V2", "true") + monkeypatch.setenv("LANGCHAIN_TRACING", "false") + with patch.object(LangChainTracer, "_update_run_single"): + with patch.object(LangChainTracer, "_persist_run_single"): + with trace_as_chain_group("test") as group_manager: + assert len(group_manager.handlers) == 1 + tracer = group_manager.handlers[0] + assert isinstance(tracer, LangChainTracer) + + with get_openai_callback() as cb: + # This is a new empty callback handler + assert cb.successful_requests == 0 + assert cb.total_tokens == 0 + + # configure adds this openai cb but doesn't modify the group manager + mngr = CallbackManager.configure(group_manager) + assert mngr.handlers == [tracer, cb] + assert group_manager.handlers == [tracer] + + response = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 2, + "completion_tokens": 1, + "total_tokens": 3, + }, + "model_name": BaseOpenAI.__fields__["model_name"].default, + }, + ) + mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) + + # The callback handler has been updated + assert cb.successful_requests == 1 + assert cb.total_tokens == 3 + assert cb.prompt_tokens == 2 + assert cb.completion_tokens == 1 + assert cb.total_cost > 0 + + with get_openai_callback() as cb: + # This is a new empty callback handler + assert cb.successful_requests == 0 + assert cb.total_tokens == 0 + + # configure adds this openai cb but doesn't modify the group manager + mngr = CallbackManager.configure(group_manager) + assert mngr.handlers == [tracer, cb] + assert group_manager.handlers == [tracer] + + response = LLMResult( + generations=[], + llm_output={ + "token_usage": { + "prompt_tokens": 2, + "completion_tokens": 1, + "total_tokens": 3, + }, + "model_name": BaseOpenAI.__fields__["model_name"].default, + }, + ) + mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) + + # The callback handler has been updated + assert cb.successful_requests == 1 + assert cb.total_tokens == 3 + assert cb.prompt_tokens == 2 + assert cb.completion_tokens == 1 + assert cb.total_cost > 0 + wait_for_all_tracers() + assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore + + +def test_trace_as_chain_group_within_tracing_v2_context_manager( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test callback manager configuration.""" monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false") monkeypatch.setenv("LANGCHAIN_TRACING", "false") + with tracing_v2_enabled(): + with patch.object(LangChainTracer, "_update_run_single"): + with patch.object(LangChainTracer, "_persist_run_single"): + with trace_as_chain_group("test") as group_manager: + assert len(group_manager.handlers) == 1 + tracer = group_manager.handlers[0] + assert isinstance(tracer, LangChainTracer) + wait_for_all_tracers() + assert LangChainTracer._persist_run_single.call_count == 1 # type: ignore - with trace_as_chain_group("test") as group_manager: - assert len(group_manager.handlers) == 1 - tracer = group_manager.handlers[0] - assert isinstance(tracer, LangChainTracer) - with get_openai_callback() as cb: - # This is a new empty callback handler - assert cb.successful_requests == 0 - assert cb.total_tokens == 0 - - # configure adds this openai cb but doesn't modify the group manager - mngr = CallbackManager.configure(group_manager) - assert mngr.handlers == [tracer, cb] - assert group_manager.handlers == [tracer] - - response = LLMResult( - generations=[], - llm_output={ - "token_usage": { - "prompt_tokens": 2, - "completion_tokens": 1, - "total_tokens": 3, - }, - "model_name": BaseOpenAI.__fields__["model_name"].default, - }, - ) - mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) - - # The callback handler has been updated - assert cb.successful_requests == 1 - assert cb.total_tokens == 3 - assert cb.prompt_tokens == 2 - assert cb.completion_tokens == 1 - assert cb.total_cost > 0 - - with get_openai_callback() as cb: - # This is a new empty callback handler - assert cb.successful_requests == 0 - assert cb.total_tokens == 0 - - # configure adds this openai cb but doesn't modify the group manager - mngr = CallbackManager.configure(group_manager) - assert mngr.handlers == [tracer, cb] - assert group_manager.handlers == [tracer] - - response = LLMResult( - generations=[], - llm_output={ - "token_usage": { - "prompt_tokens": 2, - "completion_tokens": 1, - "total_tokens": 3, - }, - "model_name": BaseOpenAI.__fields__["model_name"].default, - }, - ) - mngr.on_llm_start({}, ["prompt"])[0].on_llm_end(response) - - # The callback handler has been updated - assert cb.successful_requests == 1 - assert cb.total_tokens == 3 - assert cb.prompt_tokens == 2 - assert cb.completion_tokens == 1 - assert cb.total_cost > 0 +def test_trace_as_chain_group_tracing_disabled( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test callback manager configuration.""" + monkeypatch.setenv("LANGCHAIN_TRACING_V2", "false") + monkeypatch.setenv("LANGCHAIN_TRACING", "false") + with patch.object(LangChainTracer, "_update_run_single"): + with patch.object(LangChainTracer, "_persist_run_single"): + with trace_as_chain_group("test") as group_manager: + assert len(group_manager.handlers) == 0 + assert LangChainTracer._persist_run_single.call_count == 0 # type: ignore