[Core] Tracing: update parent run_tree's child_runs (#21049)

This commit is contained in:
William FH 2024-05-01 06:33:08 -07:00 committed by GitHub
parent 86fe484e24
commit ab55f6996d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 75 additions and 9 deletions

View File

@ -41,6 +41,7 @@ from langchain_core.callbacks.base import (
) )
from langchain_core.callbacks.stdout import StdOutCallbackHandler from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.tracers.schemas import Run
from langchain_core.utils.env import env_var_is_set from langchain_core.utils.env import env_var_is_set
if TYPE_CHECKING: if TYPE_CHECKING:
@ -1994,15 +1995,22 @@ def _configure(
callback_manager.add_handler(tracer_v2, True) callback_manager.add_handler(tracer_v2, True)
else: else:
try: try:
handler = LangChainTracer(project_name=tracer_project) handler = LangChainTracer(
project_name=tracer_project,
client=run_tree.client if run_tree is not None else None,
)
callback_manager.add_handler(handler, True) callback_manager.add_handler(handler, True)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Unable to load requested LangChainTracer." "Unable to load requested LangChainTracer."
" To disable this warning," " To disable this warning,"
" unset the LANGCHAIN_TRACING_V2 environment variables.", " unset the LANGCHAIN_TRACING_V2 environment variables.",
e, f"{repr(e)}",
) )
if run_tree is not None:
for handler in callback_manager.handlers:
if isinstance(handler, LangChainTracer):
handler.run_map[str(run_tree.id)] = cast(Run, run_tree)
for var, inheritable, handler_class, env_var in _configure_hooks: for var, inheritable, handler_class, env_var in _configure_hooks:
create_one = ( create_one = (
env_var is not None env_var is not None

View File

@ -1,4 +1,5 @@
"""Base interfaces for tracing runs.""" """Base interfaces for tracing runs."""
from __future__ import annotations from __future__ import annotations
import logging import logging
@ -102,9 +103,10 @@ class BaseTracer(BaseCallbackHandler, ABC):
parent_run = self.run_map.get(str(run.parent_run_id)) parent_run = self.run_map.get(str(run.parent_run_id))
if parent_run: if parent_run:
self._add_child_run(parent_run, run) self._add_child_run(parent_run, run)
parent_run.child_execution_order = max( if hasattr(parent_run, "child_execution_order"):
parent_run.child_execution_order, run.child_execution_order parent_run.child_execution_order = max(
) parent_run.child_execution_order, run.child_execution_order
)
run.trace_id = parent_run.trace_id run.trace_id = parent_run.trace_id
if parent_run.dotted_order: if parent_run.dotted_order:
run.dotted_order = ( run.dotted_order = (
@ -135,7 +137,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
logger.debug(f"Parent run with UUID {run.parent_run_id} not found.") logger.debug(f"Parent run with UUID {run.parent_run_id} not found.")
elif ( elif (
run.child_execution_order is not None run.child_execution_order is not None
and parent_run.child_execution_order is not None and getattr(parent_run, "child_execution_order", None) is not None
and run.child_execution_order > parent_run.child_execution_order and run.child_execution_order > parent_run.child_execution_order
): ):
parent_run.child_execution_order = run.child_execution_order parent_run.child_execution_order = run.child_execution_order
@ -151,10 +153,11 @@ class BaseTracer(BaseCallbackHandler, ABC):
if parent_run is None: if parent_run is None:
logger.debug(f"Parent run with UUID {parent_run_id} not found.") logger.debug(f"Parent run with UUID {parent_run_id} not found.")
return 1 return 1
if parent_run.child_execution_order is None: if getattr(parent_run, "child_execution_order", None) is None:
raise TracerException( logger.debug(
f"Parent run with UUID {parent_run_id} has no child execution order." f"Parent run with UUID {parent_run_id} has no child_execution_order."
) )
return 1
return parent_run.child_execution_order + 1 return parent_run.child_execution_order + 1

View File

@ -1,16 +1,21 @@
"""Test Tracer classes.""" """Test Tracer classes."""
from __future__ import annotations from __future__ import annotations
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, List from typing import Any, List
from unittest.mock import MagicMock
from uuid import uuid4 from uuid import uuid4
import langsmith
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from langsmith import Client, traceable
from langchain_core.callbacks import CallbackManager from langchain_core.callbacks import CallbackManager
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer, TracerException from langchain_core.tracers.base import BaseTracer, TracerException
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
@ -627,3 +632,33 @@ def test_tracer_nested_runs_on_error() -> None:
assert len(tracer.runs) == 3 assert len(tracer.runs) == 3
for run in tracer.runs: for run in tracer.runs:
_compare_run_with_error(run, compare_run) _compare_run_with_error(run, compare_run)
def _get_mock_client() -> Client:
mock_session = MagicMock()
client = Client(session=mock_session, api_key="test")
return client
def test_traceable_to_tracing() -> None:
has_children = False
def _collect_run(run: Any) -> None:
nonlocal has_children
has_children = bool(run.child_runs)
@as_runnable
def foo(x: int) -> int:
return x + 1
@traceable
def some_parent(a: int, b: int) -> int:
return foo.invoke(a) + foo.invoke(b)
mock_client_ = _get_mock_client()
with langsmith.run_helpers.tracing_context(enabled=True):
result = some_parent(
1, 2, langsmith_extra={"client": mock_client_, "on_end": _collect_run}
)
assert result == 5
assert has_children, "Child run not collected"

View File

@ -2,11 +2,13 @@ import threading
import time import time
import unittest import unittest
import unittest.mock import unittest.mock
import uuid
from typing import Any, Dict from typing import Any, Dict
from uuid import UUID from uuid import UUID
import pytest import pytest
from langsmith import Client from langsmith import Client
from langsmith.run_trees import RunTree
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer from langchain_core.tracers.langchain import LangChainTracer
@ -59,6 +61,24 @@ def test_example_id_assignment_threadsafe() -> None:
assert example_ids == expected_example_ids assert example_ids == expected_example_ids
def test_tracer_with_run_tree_parent() -> None:
mock_session = unittest.mock.MagicMock()
client = Client(session=mock_session, api_key="test")
parent = RunTree(name="parent", inputs={"input": "foo"}, client=client)
run_id = uuid.uuid4()
tracer = LangChainTracer(client=client)
tracer.run_map[str(parent.id)] = parent # type: ignore
tracer.on_chain_start(
{"name": "child"}, {"input": "bar"}, run_id=run_id, parent_run_id=parent.id
)
tracer.on_chain_end({}, run_id=run_id)
assert parent.child_runs
assert len(parent.child_runs) == 1
assert parent.child_runs[0].id == run_id
assert parent.child_runs[0].trace_id == parent.id
assert parent.child_runs[0].parent_run_id == parent.id
def test_log_lock() -> None: def test_log_lock() -> None:
"""Test that example assigned at callback start/end is honored.""" """Test that example assigned at callback start/end is honored."""