[Core] Tracing: update parent run_tree's child_runs (#21049)

This commit is contained in:
William FH
2024-05-01 06:33:08 -07:00
committed by GitHub
parent 86fe484e24
commit ab55f6996d
4 changed files with 75 additions and 9 deletions

View File

@@ -1,16 +1,21 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any, List
from unittest.mock import MagicMock
from uuid import uuid4
import langsmith
import pytest
from freezegun import freeze_time
from langsmith import Client, traceable
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import HumanMessage
from langchain_core.outputs import LLMResult
from langchain_core.runnables import chain as as_runnable
from langchain_core.tracers.base import BaseTracer, TracerException
from langchain_core.tracers.schemas import Run
@@ -627,3 +632,33 @@ def test_tracer_nested_runs_on_error() -> None:
assert len(tracer.runs) == 3
for run in tracer.runs:
_compare_run_with_error(run, compare_run)
def _get_mock_client() -> Client:
mock_session = MagicMock()
client = Client(session=mock_session, api_key="test")
return client
def test_traceable_to_tracing() -> None:
has_children = False
def _collect_run(run: Any) -> None:
nonlocal has_children
has_children = bool(run.child_runs)
@as_runnable
def foo(x: int) -> int:
return x + 1
@traceable
def some_parent(a: int, b: int) -> int:
return foo.invoke(a) + foo.invoke(b)
mock_client_ = _get_mock_client()
with langsmith.run_helpers.tracing_context(enabled=True):
result = some_parent(
1, 2, langsmith_extra={"client": mock_client_, "on_end": _collect_run}
)
assert result == 5
assert has_children, "Child run not collected"

View File

@@ -2,11 +2,13 @@ import threading
import time
import unittest
import unittest.mock
import uuid
from typing import Any, Dict
from uuid import UUID
import pytest
from langsmith import Client
from langsmith.run_trees import RunTree
from langchain_core.outputs import LLMResult
from langchain_core.tracers.langchain import LangChainTracer
@@ -59,6 +61,24 @@ def test_example_id_assignment_threadsafe() -> None:
assert example_ids == expected_example_ids
def test_tracer_with_run_tree_parent() -> None:
mock_session = unittest.mock.MagicMock()
client = Client(session=mock_session, api_key="test")
parent = RunTree(name="parent", inputs={"input": "foo"}, client=client)
run_id = uuid.uuid4()
tracer = LangChainTracer(client=client)
tracer.run_map[str(parent.id)] = parent # type: ignore
tracer.on_chain_start(
{"name": "child"}, {"input": "bar"}, run_id=run_id, parent_run_id=parent.id
)
tracer.on_chain_end({}, run_id=run_id)
assert parent.child_runs
assert len(parent.child_runs) == 1
assert parent.child_runs[0].id == run_id
assert parent.child_runs[0].trace_id == parent.id
assert parent.child_runs[0].parent_run_id == parent.id
def test_log_lock() -> None:
"""Test that example assigned at callback start/end is honored."""