mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
chore(core): reduce streaming metadata / perf (#36588)
- looking into reducing streaming metadata / perfm --------- Co-authored-by: William Fu-Hinthorn <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHan
|
||||
from langchain_core.runnables import RunnableBinding, RunnablePassthrough
|
||||
from langchain_core.runnables.config import (
|
||||
RunnableConfig,
|
||||
_get_langsmith_inheritable_metadata_from_config,
|
||||
_set_config_context,
|
||||
ensure_config,
|
||||
merge_configs,
|
||||
@@ -61,7 +62,7 @@ def test_ensure_config() -> None:
|
||||
assert config["configurable"] is not arg["configurable"]
|
||||
assert config == {
|
||||
"tags": ["tag1", "tag2"],
|
||||
"metadata": {"foo": "bar", "baz": "qux", "something": "else"},
|
||||
"metadata": {"foo": "bar"},
|
||||
"callbacks": [arg["callbacks"][0]],
|
||||
"recursion_limit": 100,
|
||||
"configurable": {"baz": "qux", "something": "else"},
|
||||
@@ -71,6 +72,145 @@ def test_ensure_config() -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_ensure_config_copies_model_to_metadata() -> None:
|
||||
config = ensure_config(
|
||||
{
|
||||
"configurable": {
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"checkpoint_ns": "ns-1",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"model": "gpt-4o",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"some_api_key": "opaque-token",
|
||||
"custom_setting": {"nested": True},
|
||||
"none_value": None,
|
||||
},
|
||||
"metadata": {"nooverride": 18},
|
||||
}
|
||||
)
|
||||
|
||||
assert config["metadata"] == {"nooverride": 18, "model": "gpt-4o"}
|
||||
assert config["configurable"] == {
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"checkpoint_ns": "ns-1",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"model": "gpt-4o",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"some_api_key": "opaque-token",
|
||||
"custom_setting": {"nested": True},
|
||||
"none_value": None,
|
||||
}
|
||||
|
||||
|
||||
def test_ensure_config_metadata_is_not_overridden_by_configurable_model() -> None:
|
||||
config = ensure_config(
|
||||
{
|
||||
"configurable": {
|
||||
"model": "from-configurable",
|
||||
"run_id": None,
|
||||
"checkpoint_ns": "from-configurable",
|
||||
},
|
||||
"metadata": {
|
||||
"model": "from-metadata",
|
||||
"run_id": "from-metadata",
|
||||
"checkpoint_ns": "from-metadata",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
assert config["metadata"] == {
|
||||
"model": "from-metadata",
|
||||
"run_id": "from-metadata",
|
||||
"checkpoint_ns": "from-metadata",
|
||||
}
|
||||
assert config["configurable"] == {
|
||||
"model": "from-configurable",
|
||||
"run_id": None,
|
||||
"checkpoint_ns": "from-configurable",
|
||||
}
|
||||
|
||||
|
||||
def test_ensure_config_copies_top_level_model_to_metadata() -> None:
|
||||
config = ensure_config(
|
||||
cast(
|
||||
"RunnableConfig",
|
||||
{
|
||||
"model": "gpt-4o",
|
||||
"metadata": {"nooverride": 18},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert config["metadata"] == {"nooverride": 18, "model": "gpt-4o"}
|
||||
assert config["configurable"] == {"model": "gpt-4o"}
|
||||
|
||||
|
||||
def test_get_langsmith_inheritable_metadata_from_config_uses_previous_copy_rules() -> (
|
||||
None
|
||||
):
|
||||
config = ensure_config(
|
||||
cast(
|
||||
"RunnableConfig",
|
||||
{
|
||||
"something": "else",
|
||||
"metadata": {
|
||||
"foo": "bar",
|
||||
"model": "from-metadata",
|
||||
"checkpoint_ns": "from-metadata",
|
||||
},
|
||||
"configurable": {
|
||||
"baz": "qux",
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"checkpoint_ns": "from-configurable",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"model": "from-configurable",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"api_key": "should-not-propagate",
|
||||
"__secret_key": "should-not-propagate",
|
||||
"temperature": 0.5,
|
||||
"streaming": True,
|
||||
"custom_setting": {"nested": True},
|
||||
"none_value": None,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
assert _get_langsmith_inheritable_metadata_from_config(config) == {
|
||||
"something": "else",
|
||||
"baz": "qux",
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"temperature": 0.5,
|
||||
"streaming": True,
|
||||
}
|
||||
|
||||
|
||||
async def test_merge_config_callbacks() -> None:
|
||||
manager: RunnableConfig = {
|
||||
"callbacks": CallbackManager(handlers=[StdOutCallbackHandler()])
|
||||
|
||||
@@ -1162,7 +1162,7 @@ async def test_with_config_metadata_passthrough(mocker: MockerFixture) -> None:
|
||||
"callbacks": None,
|
||||
"recursion_limit": 25,
|
||||
"configurable": {"hello": "there", "__secret_key": "nahnah"},
|
||||
"metadata": {"hello": "there", "bye": "now"},
|
||||
"metadata": {"bye": "now"},
|
||||
},
|
||||
)
|
||||
spy.reset_mock()
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from inspect import isasyncgenfunction
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
@@ -12,13 +15,15 @@ from langsmith import Client, RunTree, get_current_run_tree, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
from langsmith.utils import get_env_var
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator, Mapping
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
|
||||
|
||||
def _get_posts(client: Client) -> list[dict[str, Any]]:
|
||||
@@ -43,12 +48,15 @@ def _get_posts(client: Client) -> list[dict[str, Any]]:
|
||||
def _create_tracer_with_mocked_client(
|
||||
project_name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
mock_session = MagicMock()
|
||||
mock_client_ = Client(
|
||||
session=mock_session, api_key="test", auto_batch_tracing=False
|
||||
)
|
||||
return LangChainTracer(client=mock_client_, project_name=project_name, tags=tags)
|
||||
return LangChainTracer(
|
||||
client=mock_client_, project_name=project_name, tags=tags, metadata=metadata
|
||||
)
|
||||
|
||||
|
||||
def test_tracing_context() -> None:
|
||||
@@ -75,6 +83,38 @@ def test_tracing_context() -> None:
|
||||
assert all(post["session_name"] == project_name for post in posts)
|
||||
|
||||
|
||||
def test_inheritable_metadata_respects_explicit_metadata_with_tracing_context() -> None:
|
||||
"""Tracer defaults fill missing keys while run metadata keeps precedence."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={
|
||||
"tenant": "from_tracer",
|
||||
"shared": "from_tracer",
|
||||
},
|
||||
)
|
||||
with tracing_context(enabled=True, client=tracer.client):
|
||||
my_func.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": callbacks,
|
||||
"metadata": {"shared": "from_run", "explicit": "from_run"},
|
||||
},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
metadata = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert metadata["tenant"] == "from_tracer"
|
||||
assert metadata["shared"] == "from_run"
|
||||
assert metadata["explicit"] == "from_run"
|
||||
|
||||
|
||||
def test_config_traceable_handoff() -> None:
|
||||
if hasattr(get_env_var, "cache_clear"):
|
||||
get_env_var.cache_clear() # type: ignore[attr-defined]
|
||||
@@ -466,7 +506,10 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
):
|
||||
collected: dict[str, RunTree] = {}
|
||||
|
||||
def collect_run(run: RunTree) -> None:
|
||||
def collect_langsmith_run(run: RunTree) -> None:
|
||||
collected[str(run.id)] = run
|
||||
|
||||
def collect_tracer_run(_: LangChainTracer, run: RunTree) -> None:
|
||||
collected[str(run.id)] = run
|
||||
|
||||
if parent_type == "ls":
|
||||
@@ -476,7 +519,8 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
return child.invoke("foo")
|
||||
|
||||
assert (
|
||||
parent(langsmith_extra={"on_end": collect_run, "run_id": rid}) == "foo"
|
||||
parent(langsmith_extra={"on_end": collect_langsmith_run, "run_id": rid})
|
||||
== "foo"
|
||||
)
|
||||
assert collected
|
||||
|
||||
@@ -487,9 +531,10 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
return child.invoke("foo")
|
||||
|
||||
tracer = LangChainTracer()
|
||||
tracer._persist_run = collect_run # type: ignore[method-assign]
|
||||
|
||||
assert parent.invoke(..., {"run_id": rid, "callbacks": [tracer]}) == "foo" # type: ignore[attr-defined]
|
||||
with patch.object(LangChainTracer, "_persist_run", new=collect_tracer_run):
|
||||
assert (
|
||||
parent.invoke(..., {"run_id": rid, "callbacks": [tracer]}) == "foo" # type: ignore[attr-defined]
|
||||
)
|
||||
run = collected.get(str(rid))
|
||||
|
||||
assert run is not None
|
||||
@@ -508,3 +553,643 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
assert "afoo" in kitten_run.tags # type: ignore[operator]
|
||||
assert grandchild_run is not None
|
||||
assert kitten_run.dotted_order.startswith(grandchild_run.dotted_order)
|
||||
|
||||
|
||||
class TestTracerMetadataThroughInvoke:
|
||||
"""Tests for tracer metadata merging through invoke calls."""
|
||||
|
||||
def test_tracer_metadata_applied_to_all_runs(self) -> None:
|
||||
"""Tracer metadata appears on every run when no config metadata is set."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"env": "prod", "service": "api"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def child(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def parent(x: int) -> int:
|
||||
return child.invoke(x)
|
||||
|
||||
parent.invoke(1, {"callbacks": [tracer]})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 2
|
||||
for post in posts:
|
||||
md = post.get("extra", {}).get("metadata", {})
|
||||
assert md.get("env") == "prod", f"run {post['name']} missing env"
|
||||
assert md.get("service") == "api", f"run {post['name']} missing service"
|
||||
|
||||
def test_config_metadata_takes_precedence(self) -> None:
|
||||
"""Config metadata wins over tracer metadata for overlapping keys."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"env": "prod", "tracer_only": "yes"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer],
|
||||
"metadata": {"env": "staging", "config_only": "yes"},
|
||||
},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
# Config wins for overlapping key
|
||||
assert md["env"] == "staging"
|
||||
# Both non-overlapping keys are present
|
||||
assert md["tracer_only"] == "yes"
|
||||
assert md["config_only"] == "yes"
|
||||
|
||||
def test_nested_calls_inherit_config_metadata(self) -> None:
|
||||
"""Child runs inherit config metadata; tracer metadata fills gaps."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"tracer_key": "tracer_val"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def child(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def parent(x: int) -> int:
|
||||
return child.invoke(x)
|
||||
|
||||
parent.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer],
|
||||
"metadata": {"config_key": "config_val"},
|
||||
},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 2
|
||||
name_to_md = {
|
||||
post["name"]: post.get("extra", {}).get("metadata", {}) for post in posts
|
||||
}
|
||||
# Both parent and child should have config metadata (inherited)
|
||||
# and tracer metadata (patched in)
|
||||
for name, md in name_to_md.items():
|
||||
assert md.get("config_key") == "config_val", f"{name} missing config_key"
|
||||
assert md.get("tracer_key") == "tracer_val", f"{name} missing tracer_key"
|
||||
|
||||
def test_tracer_metadata_not_applied_to_sibling_handlers(self) -> None:
|
||||
"""Tracer metadata is not applied to other callback handlers.
|
||||
|
||||
`_patch_missing_metadata` copies the metadata dict before patching,
|
||||
so the callback manager's shared metadata dict is not mutated.
|
||||
Other handlers should only see config metadata, not tracer metadata.
|
||||
"""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"tracer_key": "tracer_val"}
|
||||
)
|
||||
|
||||
received_metadata: list[dict[str, Any]] = []
|
||||
|
||||
class MetadataCapture(BaseCallbackHandler):
|
||||
"""Callback handler that records metadata from chain events."""
|
||||
|
||||
def on_chain_start(self, *_args: Any, **kwargs: Any) -> None:
|
||||
received_metadata.append(dict(kwargs.get("metadata", {})))
|
||||
|
||||
capture = MetadataCapture()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer, capture],
|
||||
"metadata": {"shared_key": "shared_val"},
|
||||
},
|
||||
)
|
||||
|
||||
assert len(received_metadata) >= 1
|
||||
for md in received_metadata:
|
||||
assert md["shared_key"] == "shared_val"
|
||||
assert "tracer_key" not in md
|
||||
|
||||
# But the posted run DOES have tracer metadata
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) >= 1
|
||||
for post in posts:
|
||||
post_md = post.get("extra", {}).get("metadata", {})
|
||||
assert post_md["shared_key"] == "shared_val"
|
||||
assert post_md["tracer_key"] == "tracer_val"
|
||||
|
||||
def test_tracer_metadata_with_no_config_metadata(self) -> None:
|
||||
"""When no config metadata is set, tracer metadata is the sole source."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"only_from_tracer": "value"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(1, {"callbacks": [tracer]})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["only_from_tracer"] == "value"
|
||||
|
||||
def test_empty_tracer_metadata_does_not_interfere(self) -> None:
|
||||
"""Tracer with no metadata does not interfere with config metadata."""
|
||||
tracer = _create_tracer_with_mocked_client(metadata=None)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{"callbacks": [tracer], "metadata": {"config_key": "config_val"}},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["config_key"] == "config_val"
|
||||
|
||||
|
||||
def test_inheritable_metadata_nested_runs_preserve_parent_child_shape() -> None:
|
||||
"""Concurrent nested runs keep parent-child linkage within each invocation."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
@RunnableLambda
|
||||
def child(x: int) -> int:
|
||||
barrier.wait()
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def parent(x: int) -> int:
|
||||
return child.invoke(x)
|
||||
|
||||
def invoke_for_tenant(tenant: str, value: int) -> int:
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"tenant": tenant},
|
||||
)
|
||||
return parent.invoke(value, {"callbacks": callbacks})
|
||||
|
||||
threads = [
|
||||
threading.Thread(target=invoke_for_tenant, args=("alpha", 1)),
|
||||
threading.Thread(target=invoke_for_tenant, args=("beta", 2)),
|
||||
]
|
||||
for thread in threads:
|
||||
thread.start()
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 4
|
||||
parents = [post for post in posts if post["name"] == "parent"]
|
||||
children = [post for post in posts if post["name"] == "child"]
|
||||
assert len(parents) == 2
|
||||
assert len(children) == 2
|
||||
parent_ids = {parent["id"] for parent in parents}
|
||||
assert {child["parent_run_id"] for child in children} == parent_ids
|
||||
assert {
|
||||
post.get("extra", {}).get("metadata", {}).get("tenant") for post in posts
|
||||
} == {
|
||||
"alpha",
|
||||
"beta",
|
||||
}
|
||||
|
||||
|
||||
def test_inheritable_metadata_parallel_children_keep_tenant_isolation() -> None:
|
||||
"""Concurrent roots with parallel child runs keep tenant metadata isolated."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
barrier = threading.Barrier(4)
|
||||
|
||||
@RunnableLambda
|
||||
def add_one(x: int) -> int:
|
||||
barrier.wait()
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def add_two(x: int) -> int:
|
||||
barrier.wait()
|
||||
return x + 2
|
||||
|
||||
parallel = RunnableParallel(first=add_one, second=add_two)
|
||||
|
||||
def invoke_for_tenant(tenant: str, value: int) -> dict[str, int]:
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"tenant": tenant},
|
||||
)
|
||||
return parallel.invoke(value, {"callbacks": callbacks})
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
list(executor.map(invoke_for_tenant, ["alpha", "beta"], [1, 2]))
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 6
|
||||
assert {
|
||||
post.get("extra", {}).get("metadata", {}).get("tenant") for post in posts
|
||||
} == {
|
||||
"alpha",
|
||||
"beta",
|
||||
}
|
||||
posts_by_trace: dict[str, list[dict[str, Any]]] = {}
|
||||
for post in posts:
|
||||
posts_by_trace.setdefault(post["trace_id"], []).append(post)
|
||||
assert len(posts_by_trace) == 2
|
||||
assert all(len(trace_posts) == 3 for trace_posts in posts_by_trace.values())
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info < (3, 11), reason="Asyncio context vars require Python 3.11+"
|
||||
)
|
||||
async def test_langsmith_inheritable_metadata_mixed_sync_async_managers_isolated() -> (
|
||||
None
|
||||
):
|
||||
"""Sync and async manager configure paths can overlap without metadata sharing."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
@RunnableLambda
|
||||
async def async_runnable(x: int) -> int:
|
||||
await asyncio.sleep(0)
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def sync_runnable(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
async def run_sync() -> int:
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"path": "sync"},
|
||||
)
|
||||
return await asyncio.to_thread(
|
||||
sync_runnable.invoke, 1, {"callbacks": callbacks}
|
||||
)
|
||||
|
||||
async def run_async() -> int:
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"path": "async"},
|
||||
)
|
||||
return await async_runnable.ainvoke(1, {"callbacks": callbacks})
|
||||
|
||||
await asyncio.gather(run_sync(), run_async())
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 2
|
||||
assert {
|
||||
post.get("extra", {}).get("metadata", {}).get("path") for post in posts
|
||||
} == {
|
||||
"sync",
|
||||
"async",
|
||||
}
|
||||
|
||||
|
||||
class TestLangsmithInheritableTracingDefaultsInConfigure:
|
||||
"""Tests for LangSmith inheritable tracing defaults in configure."""
|
||||
|
||||
def test_langsmith_inheritable_metadata_applied_via_configure(self) -> None:
|
||||
"""langsmith_inheritable_metadata flows to a copied tracer."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"env": "prod", "service": "api"},
|
||||
)
|
||||
lc_tracers = [h for h in cm.handlers if isinstance(h, LangChainTracer)]
|
||||
assert len(lc_tracers) == 1
|
||||
assert lc_tracers[0] is not tracer
|
||||
assert lc_tracers[0].tracing_metadata == {"env": "prod", "service": "api"}
|
||||
assert tracer.tracing_metadata is None
|
||||
|
||||
def test_langsmith_inheritable_metadata_does_not_overwrite_tracer_metadata(
|
||||
self,
|
||||
) -> None:
|
||||
"""Tracer metadata takes precedence over langsmith_inheritable_metadata."""
|
||||
tracer = _create_tracer_with_mocked_client(metadata={"env": "staging"})
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"env": "prod", "service": "api"},
|
||||
)
|
||||
lc_tracer = next(h for h in cm.handlers if isinstance(h, LangChainTracer))
|
||||
assert tracer.tracing_metadata == {"env": "staging"}
|
||||
assert lc_tracer.tracing_metadata == {"env": "staging", "service": "api"}
|
||||
|
||||
def test_tracing_context_metadata_merged_into_langsmith_inheritable_metadata(
|
||||
self,
|
||||
) -> None:
|
||||
"""Tracing-context metadata merges into tracer defaults.
|
||||
|
||||
LangSmith metadata keeps precedence on collisions.
|
||||
"""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
with tracing_context(
|
||||
enabled=True,
|
||||
client=tracer.client,
|
||||
metadata={"trace_only": "value", "shared": "trace"},
|
||||
):
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={
|
||||
"shared": "langsmith",
|
||||
"tenant": "alpha",
|
||||
},
|
||||
)
|
||||
|
||||
lc_tracer = next(h for h in cm.handlers if isinstance(h, LangChainTracer))
|
||||
assert lc_tracer.tracing_metadata == {
|
||||
"trace_only": "value",
|
||||
"shared": "langsmith",
|
||||
"tenant": "alpha",
|
||||
}
|
||||
|
||||
def test_langsmith_inheritable_metadata_end_to_end(self) -> None:
|
||||
"""langsmith_inheritable_metadata in configure propagates to posted runs."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
# Use langsmith_inheritable_metadata through the config callbacks path
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"env": "prod"},
|
||||
)
|
||||
my_func.invoke(1, {"callbacks": cm})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["env"] == "prod"
|
||||
|
||||
def test_runnable_config_copies_configurable_values_to_tracing_metadata(
|
||||
self,
|
||||
) -> None:
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
config: RunnableConfig = {
|
||||
"callbacks": [tracer],
|
||||
"metadata": {
|
||||
"something": "else",
|
||||
"checkpoint_ns": "from-metadata",
|
||||
"model": "from-metadata",
|
||||
},
|
||||
"configurable": {
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"checkpoint_ns": "from-configurable",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"model": "from-configurable",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"api_key": "should-not-propagate",
|
||||
"__secret_key": "should-not-propagate",
|
||||
"temperature": 0.5,
|
||||
"streaming": True,
|
||||
"custom_setting": {"nested": True},
|
||||
"none_value": None,
|
||||
},
|
||||
}
|
||||
my_func.invoke(1, config)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert {
|
||||
key: md[key]
|
||||
for key in (
|
||||
"something",
|
||||
"thread_id",
|
||||
"checkpoint_id",
|
||||
"task_id",
|
||||
"run_id",
|
||||
"assistant_id",
|
||||
"graph_id",
|
||||
"user_id",
|
||||
"cron_id",
|
||||
"langgraph_auth_user_id",
|
||||
"temperature",
|
||||
"streaming",
|
||||
"model",
|
||||
"checkpoint_ns",
|
||||
)
|
||||
} == {
|
||||
"something": "else",
|
||||
"thread_id": "th-123",
|
||||
"checkpoint_id": "ckpt-1",
|
||||
"task_id": "task-1",
|
||||
"run_id": "run-456",
|
||||
"assistant_id": "asst-789",
|
||||
"graph_id": "graph-0",
|
||||
"user_id": "uid-1",
|
||||
"cron_id": "cron-1",
|
||||
"langgraph_auth_user_id": "user-1",
|
||||
"temperature": 0.5,
|
||||
"streaming": True,
|
||||
"model": "from-metadata",
|
||||
"checkpoint_ns": "from-metadata",
|
||||
}
|
||||
assert "api_key" not in md
|
||||
assert "__secret_key" not in md
|
||||
assert "custom_setting" not in md
|
||||
assert "none_value" not in md
|
||||
|
||||
def test_langsmith_inheritable_metadata_does_not_affect_non_tracer_handlers(
|
||||
self,
|
||||
) -> None:
|
||||
"""langsmith_inheritable_metadata only applies to tracers."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
received_metadata: list[dict[str, Any]] = []
|
||||
|
||||
class MetadataCapture(BaseCallbackHandler):
|
||||
def on_chain_start(self, *_args: Any, **kwargs: Any) -> None:
|
||||
received_metadata.append(dict(kwargs.get("metadata", {})))
|
||||
|
||||
capture = MetadataCapture()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer, capture],
|
||||
langsmith_inheritable_metadata={"tracer_only": "yes"},
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(1, {"callbacks": cm})
|
||||
|
||||
# Non-tracer handler should NOT see langsmith_inheritable_metadata
|
||||
assert len(received_metadata) >= 1
|
||||
for md in received_metadata:
|
||||
assert "tracer_only" not in md
|
||||
|
||||
# But the tracer's posted runs SHOULD have it
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) >= 1
|
||||
for post in posts:
|
||||
post_md = post.get("extra", {}).get("metadata", {})
|
||||
assert post_md["tracer_only"] == "yes"
|
||||
|
||||
def test_no_langsmith_inheritable_metadata_is_noop(self) -> None:
|
||||
"""Passing langsmith_inheritable_metadata=None does not alter tracer state."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata=None,
|
||||
)
|
||||
lc_tracer = next(h for h in cm.handlers if isinstance(h, LangChainTracer))
|
||||
assert lc_tracer is tracer
|
||||
assert tracer.tracing_metadata is None
|
||||
|
||||
def test_langsmith_inheritable_tags_applied_via_configure(self) -> None:
|
||||
"""langsmith_inheritable_tags flow to a copied tracer."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
tracer.tags = ["existing"]
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_tags=["tenant:alpha", "existing"],
|
||||
)
|
||||
lc_tracer = next(h for h in cm.handlers if isinstance(h, LangChainTracer))
|
||||
assert lc_tracer is not tracer
|
||||
assert lc_tracer.tags == ["existing", "tenant:alpha"]
|
||||
assert tracer.tags == ["existing"]
|
||||
|
||||
def test_inheritable_tags_do_not_affect_non_tracer_handlers(self) -> None:
|
||||
"""langsmith_inheritable_tags only apply to tracers."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
received_tags: list[list[str]] = []
|
||||
|
||||
class TagCapture(BaseCallbackHandler):
|
||||
def on_chain_start(self, *_args: Any, **kwargs: Any) -> None:
|
||||
received_tags.append(list(kwargs.get("tags", [])))
|
||||
|
||||
capture = TagCapture()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer, capture],
|
||||
langsmith_inheritable_tags=["tracer-only"],
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(1, {"callbacks": cm})
|
||||
|
||||
assert received_tags
|
||||
assert all("tracer-only" not in tags for tags in received_tags)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert posts
|
||||
assert all("tracer-only" in post.get("tags", []) for post in posts)
|
||||
|
||||
def test_langsmith_inheritable_metadata_copies_handlers_without_mutating_original(
|
||||
self,
|
||||
) -> None:
|
||||
"""Configured manager copies tracers and leaves the original unchanged."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"env": "prod"},
|
||||
)
|
||||
handler_tracer = next(h for h in cm.handlers if isinstance(h, LangChainTracer))
|
||||
inheritable_tracer = next(
|
||||
h for h in cm.inheritable_handlers if isinstance(h, LangChainTracer)
|
||||
)
|
||||
assert handler_tracer is not tracer
|
||||
assert inheritable_tracer is not tracer
|
||||
assert tracer.tracing_metadata is None
|
||||
|
||||
def test_langsmith_inheritable_metadata_configure_isolated_per_manager(
|
||||
self,
|
||||
) -> None:
|
||||
"""Separate configure calls keep tracer-only defaults isolated."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
alpha_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"tenant": "alpha"},
|
||||
)
|
||||
beta_manager = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"tenant": "beta"},
|
||||
)
|
||||
|
||||
alpha_tracer = next(
|
||||
handler
|
||||
for handler in alpha_manager.handlers
|
||||
if isinstance(handler, LangChainTracer)
|
||||
)
|
||||
beta_tracer = next(
|
||||
handler
|
||||
for handler in beta_manager.handlers
|
||||
if isinstance(handler, LangChainTracer)
|
||||
)
|
||||
|
||||
assert tracer.tracing_metadata is None
|
||||
assert alpha_tracer is not tracer
|
||||
assert beta_tracer is not tracer
|
||||
assert alpha_tracer is not beta_tracer
|
||||
assert alpha_tracer.tracing_metadata == {"tenant": "alpha"}
|
||||
assert beta_tracer.tracing_metadata == {"tenant": "beta"}
|
||||
assert alpha_tracer.run_map is tracer.run_map
|
||||
assert beta_tracer.run_map is tracer.run_map
|
||||
assert alpha_tracer.order_map is tracer.order_map
|
||||
assert beta_tracer.order_map is tracer.order_map
|
||||
|
||||
def test_inheritable_metadata_concurrent_invocations_remain_isolated(
|
||||
self,
|
||||
) -> None:
|
||||
"""Parallel invocations through copied tracers keep metadata separated."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
barrier = threading.Barrier(2)
|
||||
|
||||
@traceable
|
||||
def traced_leaf(x: int) -> int:
|
||||
barrier.wait()
|
||||
return x
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return traced_leaf(x)
|
||||
|
||||
def invoke_for_tenant(tenant: str, value: int) -> int:
|
||||
callbacks = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
langsmith_inheritable_metadata={"tenant": tenant},
|
||||
)
|
||||
return my_func.invoke(value, {"callbacks": callbacks})
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
list(executor.map(invoke_for_tenant, ["alpha", "beta"], [1, 2]))
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 4
|
||||
assert {post["name"] for post in posts} == {"my_func", "traced_leaf"}
|
||||
my_func_posts = [post for post in posts if post["name"] == "my_func"]
|
||||
assert len(my_func_posts) == 2
|
||||
assert {
|
||||
post.get("extra", {}).get("metadata", {}).get("tenant")
|
||||
for post in my_func_posts
|
||||
} == {"alpha", "beta"}
|
||||
assert tracer.run_map == {}
|
||||
assert len(tracer.order_map) == 2
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import concurrent.futures
|
||||
import threading
|
||||
import time
|
||||
import unittest.mock
|
||||
@@ -15,6 +16,7 @@ from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.tracers.langchain import (
|
||||
LangChainTracer,
|
||||
_get_usage_metadata_from_generations,
|
||||
_patch_missing_metadata,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
@@ -696,3 +698,206 @@ def test_on_chain_error_updates_when_not_defers_inputs() -> None:
|
||||
# Should call update (PATCH), not persist (POST) for normal inputs
|
||||
assert not persist_called
|
||||
assert update_called
|
||||
|
||||
|
||||
class TestPatchMissingMetadata:
|
||||
"""Tests for `_patch_missing_metadata` and tracer metadata behavior."""
|
||||
|
||||
@staticmethod
|
||||
def _make_tracer(
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
client.tracing_queue = None
|
||||
return LangChainTracer(client=client, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def _make_run(
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Run:
|
||||
return Run(
|
||||
id=uuid.uuid4(),
|
||||
name="test",
|
||||
inputs={},
|
||||
run_type="chain",
|
||||
extra={"metadata": metadata or {}},
|
||||
)
|
||||
|
||||
def test_adds_metadata_when_run_has_none(self) -> None:
|
||||
"""Tracer metadata fills in when the run has no matching keys."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "service": "api"})
|
||||
run = self._make_run()
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata["env"] == "prod"
|
||||
assert run.metadata["service"] == "api"
|
||||
|
||||
def test_does_not_overwrite_existing_keys(self) -> None:
|
||||
"""Config metadata takes precedence over tracer metadata."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "service": "api"})
|
||||
run = self._make_run(metadata={"env": "staging"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata["env"] == "staging"
|
||||
assert run.metadata["service"] == "api"
|
||||
|
||||
def test_noop_when_tracer_has_no_metadata(self) -> None:
|
||||
"""No-op when the tracer has no metadata configured."""
|
||||
tracer = self._make_tracer(metadata=None)
|
||||
run = self._make_run(metadata={"existing": "value"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {"existing": "value"}
|
||||
|
||||
def test_noop_when_all_keys_already_present(self) -> None:
|
||||
"""No-op when every tracer key already exists in the run."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod"})
|
||||
run = self._make_run(metadata={"env": "dev"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {"env": "dev"}
|
||||
|
||||
def test_merges_disjoint_keys(self) -> None:
|
||||
"""Disjoint keys from tracer and config are all present after patching."""
|
||||
tracer = self._make_tracer(metadata={"tracer_key": "tracer_val"})
|
||||
run = self._make_run(metadata={"config_key": "config_val"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {
|
||||
"tracer_key": "tracer_val",
|
||||
"config_key": "config_val",
|
||||
}
|
||||
|
||||
def test_persist_run_single_applies_tracer_metadata(self) -> None:
|
||||
"""End-to-end: `_persist_run_single` calls `_patch_missing_metadata`."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod"})
|
||||
run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
|
||||
tracer.on_chain_start(
|
||||
{"name": "test_chain"},
|
||||
{"input": "hello"},
|
||||
run_id=run_id,
|
||||
)
|
||||
run = tracer.run_map[str(run_id)]
|
||||
|
||||
with unittest.mock.patch.object(Run, "post"):
|
||||
tracer._persist_run_single(run)
|
||||
|
||||
assert run.metadata.get("env") == "prod"
|
||||
|
||||
def test_persist_run_single_config_metadata_wins(self) -> None:
|
||||
"""Config metadata is not overwritten by tracer metadata during persist."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "extra": "from_tracer"})
|
||||
run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160b")
|
||||
tracer.on_chain_start(
|
||||
{"name": "test_chain"},
|
||||
{"input": "hello"},
|
||||
run_id=run_id,
|
||||
metadata={"env": "staging"},
|
||||
)
|
||||
run = tracer.run_map[str(run_id)]
|
||||
|
||||
with unittest.mock.patch.object(Run, "post"):
|
||||
tracer._persist_run_single(run)
|
||||
|
||||
assert run.metadata["env"] == "staging"
|
||||
assert run.metadata["extra"] == "from_tracer"
|
||||
|
||||
|
||||
class TestTracerMetadataCloning:
|
||||
"""Tests for LangChainTracer metadata cloning helpers."""
|
||||
|
||||
@staticmethod
|
||||
def _make_tracer(
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
client.tracing_queue = None
|
||||
return LangChainTracer(client=client, metadata=metadata)
|
||||
|
||||
def test_copy_with_metadata_defaults_copies_configuration(self) -> None:
|
||||
"""Copied tracer keeps stable configuration but not identity."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
tracer.project_name = "project"
|
||||
tracer.tags = ["tag"]
|
||||
|
||||
copied = tracer.copy_with_metadata_defaults(metadata={"service": "api"})
|
||||
|
||||
assert copied is not tracer
|
||||
assert copied.client is tracer.client
|
||||
assert copied.project_name == "project"
|
||||
assert copied.tags == ["tag"]
|
||||
assert copied.tags is tracer.tags
|
||||
assert copied.tracing_metadata == {"env": "staging", "service": "api"}
|
||||
assert copied.run_map is tracer.run_map
|
||||
assert copied.order_map is tracer.order_map
|
||||
assert copied.run_has_token_event_map == {}
|
||||
|
||||
def test_copy_with_metadata_defaults_does_not_mutate_original(self) -> None:
|
||||
"""Metadata-default cloning leaves the source tracer unchanged."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
|
||||
copied = tracer.copy_with_metadata_defaults(metadata={"service": "api"})
|
||||
|
||||
assert tracer.tracing_metadata == {"env": "staging"}
|
||||
assert copied.tracing_metadata == {"env": "staging", "service": "api"}
|
||||
|
||||
def test_copy_with_metadata_defaults_none_preserves_configuration(self) -> None:
|
||||
"""Copying without new metadata preserves metadata and shared run state."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
copied = tracer.copy_with_metadata_defaults(metadata=None)
|
||||
|
||||
assert copied is not tracer
|
||||
assert copied.tracing_metadata == {"env": "staging"}
|
||||
assert copied.run_map is tracer.run_map
|
||||
assert copied.order_map is tracer.order_map
|
||||
|
||||
def test_copy_with_metadata_defaults_threadsafe(self) -> None:
|
||||
"""Concurrent metadata-default copies do not mutate each other or the source."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
|
||||
def copy_for_service(service: str) -> dict[str, str]:
|
||||
copied = tracer.copy_with_metadata_defaults(metadata={"service": service})
|
||||
assert copied is not tracer
|
||||
return copied.tracing_metadata or {}
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
metadata_values = list(executor.map(copy_for_service, ["api", "worker"]))
|
||||
|
||||
assert tracer.tracing_metadata == {"env": "staging"}
|
||||
assert {metadata["service"] for metadata in metadata_values} == {
|
||||
"api",
|
||||
"worker",
|
||||
}
|
||||
assert all(metadata["env"] == "staging" for metadata in metadata_values)
|
||||
|
||||
def test_copy_with_metadata_defaults_threadsafe_with_existing_shared_state(
|
||||
self,
|
||||
) -> None:
|
||||
"""Concurrent copies preserve pre-populated shared run state."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
run_id = uuid.uuid4()
|
||||
tracer.run_map["existing"] = unittest.mock.MagicMock()
|
||||
tracer.order_map[run_id] = (run_id, f"prefix.{run_id}")
|
||||
|
||||
def copy_for_service(service: str) -> LangChainTracer:
|
||||
copied = tracer.copy_with_metadata_defaults(metadata={"service": service})
|
||||
assert copied.run_map is tracer.run_map
|
||||
assert copied.order_map is tracer.order_map
|
||||
return copied
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
copied_tracers = list(executor.map(copy_for_service, ["api", "worker"]))
|
||||
|
||||
assert tracer.run_map.keys() == {"existing"}
|
||||
assert tracer.order_map == {run_id: (run_id, f"prefix.{run_id}")}
|
||||
copied_services = {
|
||||
copied.tracing_metadata["service"]
|
||||
for copied in copied_tracers
|
||||
if copied.tracing_metadata is not None
|
||||
}
|
||||
assert copied_services == {"api", "worker"}
|
||||
|
||||
Reference in New Issue
Block a user