mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-01 02:50:47 +00:00
[Core] Tracing: update parent run_tree's child_runs (#21049)
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
"""Test Tracer classes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, List
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import langsmith
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from langsmith import Client, traceable
|
||||
|
||||
from langchain_core.callbacks import CallbackManager
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.runnables import chain as as_runnable
|
||||
from langchain_core.tracers.base import BaseTracer, TracerException
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
@@ -627,3 +632,33 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
assert len(tracer.runs) == 3
|
||||
for run in tracer.runs:
|
||||
_compare_run_with_error(run, compare_run)
|
||||
|
||||
|
||||
def _get_mock_client() -> Client:
|
||||
mock_session = MagicMock()
|
||||
client = Client(session=mock_session, api_key="test")
|
||||
return client
|
||||
|
||||
|
||||
def test_traceable_to_tracing() -> None:
|
||||
has_children = False
|
||||
|
||||
def _collect_run(run: Any) -> None:
|
||||
nonlocal has_children
|
||||
has_children = bool(run.child_runs)
|
||||
|
||||
@as_runnable
|
||||
def foo(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@traceable
|
||||
def some_parent(a: int, b: int) -> int:
|
||||
return foo.invoke(a) + foo.invoke(b)
|
||||
|
||||
mock_client_ = _get_mock_client()
|
||||
with langsmith.run_helpers.tracing_context(enabled=True):
|
||||
result = some_parent(
|
||||
1, 2, langsmith_extra={"client": mock_client_, "on_end": _collect_run}
|
||||
)
|
||||
assert result == 5
|
||||
assert has_children, "Child run not collected"
|
||||
|
@@ -2,11 +2,13 @@ import threading
|
||||
import time
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import uuid
|
||||
from typing import Any, Dict
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from langsmith import Client
|
||||
from langsmith.run_trees import RunTree
|
||||
|
||||
from langchain_core.outputs import LLMResult
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
@@ -59,6 +61,24 @@ def test_example_id_assignment_threadsafe() -> None:
|
||||
assert example_ids == expected_example_ids
|
||||
|
||||
|
||||
def test_tracer_with_run_tree_parent() -> None:
|
||||
mock_session = unittest.mock.MagicMock()
|
||||
client = Client(session=mock_session, api_key="test")
|
||||
parent = RunTree(name="parent", inputs={"input": "foo"}, client=client)
|
||||
run_id = uuid.uuid4()
|
||||
tracer = LangChainTracer(client=client)
|
||||
tracer.run_map[str(parent.id)] = parent # type: ignore
|
||||
tracer.on_chain_start(
|
||||
{"name": "child"}, {"input": "bar"}, run_id=run_id, parent_run_id=parent.id
|
||||
)
|
||||
tracer.on_chain_end({}, run_id=run_id)
|
||||
assert parent.child_runs
|
||||
assert len(parent.child_runs) == 1
|
||||
assert parent.child_runs[0].id == run_id
|
||||
assert parent.child_runs[0].trace_id == parent.id
|
||||
assert parent.child_runs[0].parent_run_id == parent.id
|
||||
|
||||
|
||||
def test_log_lock() -> None:
|
||||
"""Test that example assigned at callback start/end is honored."""
|
||||
|
||||
|
Reference in New Issue
Block a user