[Core] Add more interops tests (#26841)

To test that the client propagates both ways
This commit is contained in:
William FH
2024-09-24 20:18:20 -07:00
committed by GitHub
parent 9b6ac41442
commit 82b5b77940
5 changed files with 36 additions and 18 deletions

View File

@@ -10,6 +10,7 @@ from langsmith import Client, get_current_run_tree, traceable
from langsmith.run_helpers import tracing_context
from langsmith.run_trees import RunTree
from langsmith.utils import get_env_var
from typing_extensions import Literal
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
from langchain_core.tracers.langchain import LangChainTracer
@@ -361,7 +362,8 @@ async def test_runnable_sequence_parallel_trace_nesting(method: str) -> None:
assert dotted_order.split(".")[0] == dotted_order
def test_tree_is_constructed() -> None:
@pytest.mark.parametrize("parent_type", ("ls", "lc"))
def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
mock_session = MagicMock()
mock_client_ = Client(
session=mock_session, api_key="test", auto_batch_tracing=False
@@ -379,27 +381,39 @@ def test_tree_is_constructed() -> None:
def child(x: str) -> str:
return grandchild.invoke(x)
@traceable
def parent() -> str:
return child.invoke("foo")
collected: dict[str, RunTree] = {} # noqa
def collect_run(run: RunTree) -> None:
collected[str(run.id)] = run
rid = uuid.uuid4()
with tracing_context(
client=mock_client_,
enabled=True,
metadata={"some_foo": "some_bar"},
tags=["afoo"],
):
assert parent(langsmith_extra={"on_end": collect_run, "run_id": rid}) == "foo"
if parent_type == "ls":
collected: dict[str, RunTree] = {} # noqa
def collect_run(run: RunTree) -> None:
collected[str(run.id)] = run
@traceable
def parent() -> str:
return child.invoke("foo")
assert (
parent(langsmith_extra={"on_end": collect_run, "run_id": rid}) == "foo"
)
assert collected
run = collected.get(str(rid))
else:
@RunnableLambda
def parent(_) -> str: # type: ignore
return child.invoke("foo")
tracer = LangChainTracer()
assert parent.invoke(..., {"run_id": rid, "callbacks": [tracer]}) == "foo" # type: ignore
run = tracer.latest_run
assert collected
run = collected.get(str(rid))
assert run is not None
assert run.name == "parent"
assert run.child_runs