mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-12 23:42:51 +00:00
Compare commits
6 Commits
langchain-
...
wfh/replic
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
94c90d5899 | ||
|
|
cf647b6d14 | ||
|
|
c08bf05e28 | ||
|
|
fa7b11d9dd | ||
|
|
e9dc3ee4c8 | ||
|
|
b558f081d3 |
@@ -6,7 +6,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from tenacity import RetryCallState
|
||||
@@ -948,6 +948,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
inheritable_tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
inheritable_metadata: dict[str, Any] | None = None,
|
||||
tracing_metadata: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize callback manager.
|
||||
|
||||
@@ -959,6 +960,9 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
inheritable_tags: The inheritable tags.
|
||||
metadata: The metadata.
|
||||
inheritable_metadata: The inheritable metadata.
|
||||
tracing_metadata: Per-invocation default metadata merged into every run
|
||||
started by this manager. Keys already present in a run's metadata
|
||||
are not overwritten.
|
||||
"""
|
||||
self.handlers: list[BaseCallbackHandler] = handlers
|
||||
self.inheritable_handlers: list[BaseCallbackHandler] = (
|
||||
@@ -969,6 +973,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
self.tracing_metadata: Mapping[str, str] | None = tracing_metadata
|
||||
|
||||
def copy(self) -> Self:
|
||||
"""Return a copy of the callback manager."""
|
||||
@@ -980,6 +985,7 @@ class BaseCallbackManager(CallbackManagerMixin):
|
||||
inheritable_tags=self.inheritable_tags.copy(),
|
||||
metadata=self.metadata.copy(),
|
||||
inheritable_metadata=self.inheritable_metadata.copy(),
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
def merge(self, other: BaseCallbackManager) -> Self:
|
||||
|
||||
@@ -7,7 +7,7 @@ import atexit
|
||||
import functools
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Callable, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from contextvars import copy_context
|
||||
@@ -269,6 +269,8 @@ def handle_event(
|
||||
**kwargs: The keyword arguments to pass to the event handler
|
||||
|
||||
"""
|
||||
# Pop tracing_metadata; only forwarded to handlers that opt in.
|
||||
tracing_metadata = kwargs.pop("_tracing_metadata", None)
|
||||
coros: list[Coroutine[Any, Any, Any]] = []
|
||||
|
||||
try:
|
||||
@@ -278,7 +280,16 @@ def handle_event(
|
||||
if ignore_condition_name is None or not getattr(
|
||||
handler, ignore_condition_name
|
||||
):
|
||||
event = getattr(handler, event_name)(*args, **kwargs)
|
||||
handler_kwargs = (
|
||||
{
|
||||
**kwargs,
|
||||
"tracing_metadata": tracing_metadata,
|
||||
}
|
||||
if tracing_metadata
|
||||
and getattr(handler, "_accepts_tracing_metadata", False)
|
||||
else kwargs
|
||||
)
|
||||
event = getattr(handler, event_name)(*args, **handler_kwargs)
|
||||
if asyncio.iscoroutine(event):
|
||||
coros.append(event)
|
||||
except NotImplementedError as e:
|
||||
@@ -370,6 +381,10 @@ async def _ahandle_event_for_handler(
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Pop tracing_metadata; only forwarded to handlers that opt in.
|
||||
tracing_metadata = kwargs.pop("_tracing_metadata", None)
|
||||
if tracing_metadata and getattr(handler, "_accepts_tracing_metadata", False):
|
||||
kwargs = {**kwargs, "tracing_metadata": tracing_metadata}
|
||||
try:
|
||||
if ignore_condition_name is None or not getattr(handler, ignore_condition_name):
|
||||
event = getattr(handler, event_name)
|
||||
@@ -466,6 +481,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
inheritable_tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
inheritable_metadata: dict[str, Any] | None = None,
|
||||
tracing_metadata: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the run manager.
|
||||
|
||||
@@ -478,6 +494,8 @@ class BaseRunManager(RunManagerMixin):
|
||||
inheritable_tags: The list of inheritable tags.
|
||||
metadata: The metadata.
|
||||
inheritable_metadata: The inheritable metadata.
|
||||
tracing_metadata: Per-invocation default metadata merged into runs
|
||||
started by tracer handlers.
|
||||
|
||||
"""
|
||||
self.run_id = run_id
|
||||
@@ -488,6 +506,7 @@ class BaseRunManager(RunManagerMixin):
|
||||
self.inheritable_tags = inheritable_tags or []
|
||||
self.metadata = metadata or {}
|
||||
self.inheritable_metadata = inheritable_metadata or {}
|
||||
self.tracing_metadata: Mapping[str, str] | None = tracing_metadata
|
||||
|
||||
@classmethod
|
||||
def get_noop_manager(cls) -> Self:
|
||||
@@ -578,6 +597,7 @@ class ParentRunManager(RunManager):
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
manager.tracing_metadata = self.tracing_metadata
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], inherit=False)
|
||||
return manager
|
||||
@@ -662,6 +682,7 @@ class AsyncParentRunManager(AsyncRunManager):
|
||||
manager.set_handlers(self.inheritable_handlers)
|
||||
manager.add_tags(self.inheritable_tags)
|
||||
manager.add_metadata(self.inheritable_metadata)
|
||||
manager.tracing_metadata = self.tracing_metadata
|
||||
if tag is not None:
|
||||
manager.add_tags([tag], inherit=False)
|
||||
return manager
|
||||
@@ -1335,6 +1356,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1348,6 +1370,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1389,6 +1412,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1402,6 +1426,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1438,6 +1463,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1450,6 +1476,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -1498,6 +1525,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
inputs=inputs,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1510,6 +1538,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -1546,6 +1575,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1558,6 +1588,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
def on_custom_event(
|
||||
@@ -1614,6 +1645,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
local_tags: list[str] | None = None,
|
||||
inheritable_metadata: dict[str, Any] | None = None,
|
||||
local_metadata: dict[str, Any] | None = None,
|
||||
*,
|
||||
tracing_metadata: Mapping[str, str] | None = None,
|
||||
) -> CallbackManager:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@@ -1625,6 +1658,8 @@ class CallbackManager(BaseCallbackManager):
|
||||
local_tags: The local tags.
|
||||
inheritable_metadata: The inheritable metadata.
|
||||
local_metadata: The local metadata.
|
||||
tracing_metadata: Default metadata merged into runs started by
|
||||
tracer handlers. Existing run metadata keys are not overwritten.
|
||||
|
||||
Returns:
|
||||
The configured callback manager.
|
||||
@@ -1638,6 +1673,7 @@ class CallbackManager(BaseCallbackManager):
|
||||
inheritable_metadata,
|
||||
local_metadata,
|
||||
verbose=verbose,
|
||||
tracing_metadata=tracing_metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -1826,6 +1862,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@@ -1841,6 +1878,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
@@ -1855,6 +1893,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1909,6 +1948,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
if handler.run_inline:
|
||||
@@ -1926,6 +1966,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1970,6 +2011,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1982,6 +2024,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -2018,6 +2061,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2030,6 +2074,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
async def on_custom_event(
|
||||
@@ -2110,6 +2155,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
parent_run_id=self.parent_run_id,
|
||||
tags=self.tags,
|
||||
metadata=self.metadata,
|
||||
_tracing_metadata=self.tracing_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2122,6 +2168,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_tags=self.inheritable_tags,
|
||||
metadata=self.metadata,
|
||||
inheritable_metadata=self.inheritable_metadata,
|
||||
tracing_metadata=self.tracing_metadata,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -2134,6 +2181,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
local_tags: list[str] | None = None,
|
||||
inheritable_metadata: dict[str, Any] | None = None,
|
||||
local_metadata: dict[str, Any] | None = None,
|
||||
*,
|
||||
tracing_metadata: Mapping[str, str] | None = None,
|
||||
) -> AsyncCallbackManager:
|
||||
"""Configure the async callback manager.
|
||||
|
||||
@@ -2145,6 +2194,8 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
local_tags: The local tags.
|
||||
inheritable_metadata: The inheritable metadata.
|
||||
local_metadata: The local metadata.
|
||||
tracing_metadata: Default metadata merged into runs started by
|
||||
tracer handlers. Existing run metadata keys are not overwritten.
|
||||
|
||||
Returns:
|
||||
The configured async callback manager.
|
||||
@@ -2158,6 +2209,7 @@ class AsyncCallbackManager(BaseCallbackManager):
|
||||
inheritable_metadata,
|
||||
local_metadata,
|
||||
verbose=verbose,
|
||||
tracing_metadata=tracing_metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -2304,6 +2356,7 @@ def _configure(
|
||||
local_metadata: dict[str, Any] | None = None,
|
||||
*,
|
||||
verbose: bool = False,
|
||||
tracing_metadata: Mapping[str, str] | None = None,
|
||||
) -> T:
|
||||
"""Configure the callback manager.
|
||||
|
||||
@@ -2316,6 +2369,8 @@ def _configure(
|
||||
inheritable_metadata: The inheritable metadata.
|
||||
local_metadata: The local metadata.
|
||||
verbose: Whether to enable verbose mode.
|
||||
tracing_metadata: Default metadata merged into runs started by
|
||||
tracer handlers. Existing run metadata keys are not overwritten.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If `LANGCHAIN_TRACING` is set but `LANGCHAIN_TRACING_V2` is not.
|
||||
@@ -2336,7 +2391,7 @@ def _configure(
|
||||
from langchain_core.tracers.stdout import ConsoleCallbackHandler # noqa: PLC0415
|
||||
|
||||
tracing_context = get_tracing_context()
|
||||
tracing_metadata = tracing_context["metadata"]
|
||||
context_metadata = tracing_context["metadata"]
|
||||
tracing_tags = tracing_context["tags"]
|
||||
run_tree: Run | None = tracing_context["parent"]
|
||||
parent_run_id = None if run_tree is None else run_tree.id
|
||||
@@ -2373,6 +2428,7 @@ def _configure(
|
||||
inheritable_tags=inheritable_callbacks.inheritable_tags.copy(),
|
||||
metadata=inheritable_callbacks.metadata.copy(),
|
||||
inheritable_metadata=inheritable_callbacks.inheritable_metadata.copy(),
|
||||
tracing_metadata=(inheritable_callbacks.tracing_metadata),
|
||||
)
|
||||
local_handlers_ = (
|
||||
local_callbacks
|
||||
@@ -2387,8 +2443,8 @@ def _configure(
|
||||
if inheritable_metadata or local_metadata:
|
||||
callback_manager.add_metadata(inheritable_metadata or {})
|
||||
callback_manager.add_metadata(local_metadata or {}, inherit=False)
|
||||
if tracing_metadata:
|
||||
callback_manager.add_metadata(tracing_metadata.copy())
|
||||
if context_metadata:
|
||||
callback_manager.add_metadata(context_metadata.copy())
|
||||
if tracing_tags:
|
||||
callback_manager.add_tags(tracing_tags.copy())
|
||||
|
||||
@@ -2479,6 +2535,8 @@ def _configure(
|
||||
for handler in callback_manager.handlers
|
||||
):
|
||||
callback_manager.add_handler(var_handler, inheritable)
|
||||
if tracing_metadata:
|
||||
callback_manager.tracing_metadata = tracing_metadata
|
||||
return callback_manager
|
||||
|
||||
|
||||
|
||||
@@ -33,6 +33,9 @@ logger = logging.getLogger(__name__)
|
||||
class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
"""Base interface for tracers."""
|
||||
|
||||
# Opt in to receiving per-invocation tracing_metadata via handle_event.
|
||||
_accepts_tracing_metadata = True
|
||||
|
||||
@abstractmethod
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
@@ -59,6 +62,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
parent_run_id: UUID | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a chat model run.
|
||||
@@ -77,6 +81,8 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
parent_run_id: The parent run ID.
|
||||
metadata: The metadata for the run.
|
||||
name: The name of the run.
|
||||
tracing_metadata: Per-invocation default metadata to merge
|
||||
into the run. Existing run metadata keys are not overwritten.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
@@ -92,6 +98,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = chat_model_run.extra.get("metadata") or {}
|
||||
chat_model_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(chat_model_run)
|
||||
self._on_chat_model_start(chat_model_run)
|
||||
return chat_model_run
|
||||
@@ -106,6 +115,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
parent_run_id: UUID | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run.
|
||||
@@ -118,6 +128,8 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
parent_run_id: The parent run ID.
|
||||
metadata: The metadata for the run.
|
||||
name: The name of the run.
|
||||
tracing_metadata: Per-invocation default metadata to merge
|
||||
into the run. Existing run metadata keys are not overwritten.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
@@ -133,6 +145,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = llm_run.extra.get("metadata") or {}
|
||||
llm_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(llm_run)
|
||||
self._on_llm_start(llm_run)
|
||||
return llm_run
|
||||
@@ -260,6 +275,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
metadata: dict[str, Any] | None = None,
|
||||
run_type: str | None = None,
|
||||
name: str | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a chain run.
|
||||
@@ -273,6 +289,8 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
metadata: The metadata for the run.
|
||||
run_type: The type of the run.
|
||||
name: The name of the run.
|
||||
tracing_metadata: Per-invocation default metadata to merge
|
||||
into the run. Existing run metadata keys are not overwritten.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
@@ -289,6 +307,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = chain_run.extra.get("metadata") or {}
|
||||
chain_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(chain_run)
|
||||
self._on_chain_start(chain_run)
|
||||
return chain_run
|
||||
@@ -362,6 +383,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
metadata: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for a tool run.
|
||||
@@ -375,6 +397,8 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
metadata: The metadata for the run.
|
||||
name: The name of the run.
|
||||
inputs: The inputs for the tool.
|
||||
tracing_metadata: Per-invocation default metadata to merge
|
||||
into the run. Existing run metadata keys are not overwritten.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
@@ -391,6 +415,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
inputs=inputs,
|
||||
**kwargs,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = tool_run.extra.get("metadata") or {}
|
||||
tool_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(tool_run)
|
||||
self._on_tool_start(tool_run)
|
||||
return tool_run
|
||||
@@ -451,6 +478,7 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
tags: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Run when the `Retriever` starts running.
|
||||
@@ -463,6 +491,8 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
tags: The tags for the run.
|
||||
metadata: The metadata for the run.
|
||||
name: The name of the run.
|
||||
tracing_metadata: Per-invocation default metadata to merge
|
||||
into the run. Existing run metadata keys are not overwritten.
|
||||
**kwargs: Additional arguments.
|
||||
|
||||
Returns:
|
||||
@@ -478,6 +508,9 @@ class BaseTracer(_TracerCore, BaseCallbackHandler, ABC):
|
||||
name=name,
|
||||
**kwargs,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = retrieval_run.extra.get("metadata") or {}
|
||||
retrieval_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(retrieval_run)
|
||||
self._on_retriever_start(retrieval_run)
|
||||
return retrieval_run
|
||||
|
||||
@@ -27,6 +27,8 @@ from langchain_core.tracers.base import BaseTracer
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
|
||||
@@ -124,6 +126,8 @@ class LangChainTracer(BaseTracer):
|
||||
project_name: str | None = None,
|
||||
client: Client | None = None,
|
||||
tags: list[str] | None = None,
|
||||
*,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the LangChain tracer.
|
||||
@@ -139,6 +143,9 @@ class LangChainTracer(BaseTracer):
|
||||
tags: The tags.
|
||||
|
||||
Defaults to an empty list.
|
||||
metadata: Additional metadata to include if it isn't already in the run.
|
||||
|
||||
Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
@@ -150,6 +157,24 @@ class LangChainTracer(BaseTracer):
|
||||
self.tags = tags or []
|
||||
self.latest_run: Run | None = None
|
||||
self.run_has_token_event_map: dict[str, bool] = {}
|
||||
self.tracing_metadata: dict[str, str] | None = (
|
||||
dict(metadata) if metadata is not None else None
|
||||
)
|
||||
|
||||
def set_defaults(self, *, metadata: Mapping[str, str] | None = None) -> None:
|
||||
"""Set default tracer values, only filling in keys not already present.
|
||||
|
||||
Args:
|
||||
metadata: Default metadata to include on runs. Keys already present
|
||||
in `tracing_metadata` are not overwritten.
|
||||
"""
|
||||
if metadata is not None:
|
||||
if self.tracing_metadata is None:
|
||||
self.tracing_metadata = dict(metadata)
|
||||
else:
|
||||
for k, v in metadata.items():
|
||||
if k not in self.tracing_metadata:
|
||||
self.tracing_metadata[k] = v
|
||||
|
||||
def _start_trace(self, run: Run) -> None:
|
||||
if self.project_name:
|
||||
@@ -176,6 +201,7 @@ class LangChainTracer(BaseTracer):
|
||||
parent_run_id: UUID | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
name: str | None = None,
|
||||
tracing_metadata: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Run:
|
||||
"""Start a trace for an LLM run.
|
||||
@@ -188,6 +214,7 @@ class LangChainTracer(BaseTracer):
|
||||
parent_run_id: The parent run ID.
|
||||
metadata: The metadata.
|
||||
name: The name.
|
||||
tracing_metadata: Per-invocation default metadata.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
@@ -208,6 +235,9 @@ class LangChainTracer(BaseTracer):
|
||||
tags=tags,
|
||||
name=name,
|
||||
)
|
||||
if tracing_metadata:
|
||||
existing = chat_model_run.extra.get("metadata") or {}
|
||||
chat_model_run.extra["metadata"] = {**tracing_metadata, **existing}
|
||||
self._start_trace(chat_model_run)
|
||||
self._on_chat_model_start(chat_model_run)
|
||||
return chat_model_run
|
||||
@@ -263,6 +293,7 @@ class LangChainTracer(BaseTracer):
|
||||
try:
|
||||
run.extra["runtime"] = get_runtime_environment()
|
||||
run.tags = self._get_tags(run)
|
||||
_patch_missing_metadata(self, run)
|
||||
if run.ls_client is not self.client:
|
||||
run.ls_client = self.client
|
||||
run.post()
|
||||
@@ -398,3 +429,18 @@ class LangChainTracer(BaseTracer):
|
||||
"""Wait for the given futures to complete."""
|
||||
if self.client is not None:
|
||||
self.client.flush()
|
||||
|
||||
|
||||
def _patch_missing_metadata(self: LangChainTracer, run: Run) -> None:
|
||||
# Apply constructor-set tracing_metadata as defaults (fill-not-overwrite).
|
||||
if not self.tracing_metadata:
|
||||
return
|
||||
metadata = run.metadata
|
||||
patched = None
|
||||
for k, v in self.tracing_metadata.items():
|
||||
if k not in metadata:
|
||||
if patched is None:
|
||||
# Copy on first miss to avoid mutating the shared dict.
|
||||
patched = {**metadata}
|
||||
run.extra["metadata"] = patched
|
||||
patched[k] = v
|
||||
|
||||
@@ -12,13 +12,13 @@ from langsmith import Client, RunTree, get_current_run_tree, traceable
|
||||
from langsmith.run_helpers import tracing_context
|
||||
from langsmith.utils import get_env_var
|
||||
|
||||
from langchain_core.callbacks.base import BaseCallbackHandler
|
||||
from langchain_core.callbacks.manager import CallbackManager
|
||||
from langchain_core.runnables.base import RunnableLambda, RunnableParallel
|
||||
from langchain_core.tracers.langchain import LangChainTracer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from collections.abc import AsyncGenerator, Callable, Coroutine, Generator, Mapping
|
||||
|
||||
|
||||
def _get_posts(client: Client) -> list[dict[str, Any]]:
|
||||
@@ -43,12 +43,15 @@ def _get_posts(client: Client) -> list[dict[str, Any]]:
|
||||
def _create_tracer_with_mocked_client(
|
||||
project_name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
metadata: Mapping[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
mock_session = MagicMock()
|
||||
mock_client_ = Client(
|
||||
session=mock_session, api_key="test", auto_batch_tracing=False
|
||||
)
|
||||
return LangChainTracer(client=mock_client_, project_name=project_name, tags=tags)
|
||||
return LangChainTracer(
|
||||
client=mock_client_, project_name=project_name, tags=tags, metadata=metadata
|
||||
)
|
||||
|
||||
|
||||
def test_tracing_context() -> None:
|
||||
@@ -508,3 +511,261 @@ def test_tree_is_constructed(parent_type: Literal["ls", "lc"]) -> None:
|
||||
assert "afoo" in kitten_run.tags # type: ignore[operator]
|
||||
assert grandchild_run is not None
|
||||
assert kitten_run.dotted_order.startswith(grandchild_run.dotted_order)
|
||||
|
||||
|
||||
class TestTracerMetadataThroughInvoke:
|
||||
"""Tests for tracer metadata merging through invoke calls."""
|
||||
|
||||
def test_tracer_metadata_applied_to_all_runs(self) -> None:
|
||||
"""Tracer metadata appears on every run when no config metadata is set."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"env": "prod", "service": "api"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def child(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def parent(x: int) -> int:
|
||||
return child.invoke(x)
|
||||
|
||||
parent.invoke(1, {"callbacks": [tracer]})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 2
|
||||
for post in posts:
|
||||
md = post.get("extra", {}).get("metadata", {})
|
||||
assert md.get("env") == "prod", f"run {post['name']} missing env"
|
||||
assert md.get("service") == "api", f"run {post['name']} missing service"
|
||||
|
||||
def test_config_metadata_takes_precedence(self) -> None:
|
||||
"""Config metadata wins over tracer metadata for overlapping keys."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"env": "prod", "tracer_only": "yes"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer],
|
||||
"metadata": {"env": "staging", "config_only": "yes"},
|
||||
},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
# Config wins for overlapping key
|
||||
assert md["env"] == "staging"
|
||||
# Both non-overlapping keys are present
|
||||
assert md["tracer_only"] == "yes"
|
||||
assert md["config_only"] == "yes"
|
||||
|
||||
def test_nested_calls_inherit_config_metadata(self) -> None:
|
||||
"""Child runs inherit config metadata; tracer metadata fills gaps."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"tracer_key": "tracer_val"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def child(x: int) -> int:
|
||||
return x + 1
|
||||
|
||||
@RunnableLambda
|
||||
def parent(x: int) -> int:
|
||||
return child.invoke(x)
|
||||
|
||||
parent.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer],
|
||||
"metadata": {"config_key": "config_val"},
|
||||
},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 2
|
||||
name_to_md = {
|
||||
post["name"]: post.get("extra", {}).get("metadata", {}) for post in posts
|
||||
}
|
||||
# Both parent and child should have config metadata (inherited)
|
||||
# and tracer metadata (patched in)
|
||||
for name, md in name_to_md.items():
|
||||
assert md.get("config_key") == "config_val", f"{name} missing config_key"
|
||||
assert md.get("tracer_key") == "tracer_val", f"{name} missing tracer_key"
|
||||
|
||||
def test_tracer_metadata_not_leaked_to_sibling_handlers(self) -> None:
|
||||
"""Tracer metadata does not leak to other callback handlers.
|
||||
|
||||
`_patch_missing_metadata` copies the metadata dict before patching,
|
||||
so the callback manager's shared metadata dict is not mutated.
|
||||
Other handlers should only see config metadata, not tracer metadata.
|
||||
"""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"tracer_key": "tracer_val"}
|
||||
)
|
||||
|
||||
received_metadata: list[dict[str, Any]] = []
|
||||
|
||||
class MetadataCapture(BaseCallbackHandler):
|
||||
"""Callback handler that records metadata from chain events."""
|
||||
|
||||
def on_chain_start(self, *_args: Any, **kwargs: Any) -> None:
|
||||
received_metadata.append(dict(kwargs.get("metadata", {})))
|
||||
|
||||
capture = MetadataCapture()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{
|
||||
"callbacks": [tracer, capture],
|
||||
"metadata": {"shared_key": "shared_val"},
|
||||
},
|
||||
)
|
||||
|
||||
assert len(received_metadata) >= 1
|
||||
for md in received_metadata:
|
||||
assert md["shared_key"] == "shared_val"
|
||||
assert "tracer_key" not in md
|
||||
|
||||
# But the posted run DOES have tracer metadata
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) >= 1
|
||||
for post in posts:
|
||||
post_md = post.get("extra", {}).get("metadata", {})
|
||||
assert post_md["shared_key"] == "shared_val"
|
||||
assert post_md["tracer_key"] == "tracer_val"
|
||||
|
||||
def test_tracer_metadata_with_no_config_metadata(self) -> None:
|
||||
"""When no config metadata is set, tracer metadata is the sole source."""
|
||||
tracer = _create_tracer_with_mocked_client(
|
||||
metadata={"only_from_tracer": "value"}
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(1, {"callbacks": [tracer]})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["only_from_tracer"] == "value"
|
||||
|
||||
def test_empty_tracer_metadata_does_not_interfere(self) -> None:
|
||||
"""Tracer with no metadata does not interfere with config metadata."""
|
||||
tracer = _create_tracer_with_mocked_client(metadata=None)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(
|
||||
1,
|
||||
{"callbacks": [tracer], "metadata": {"config_key": "config_val"}},
|
||||
)
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["config_key"] == "config_val"
|
||||
|
||||
|
||||
class TestTracingMetadataInConfigure:
|
||||
"""Tests for `tracing_metadata` parameter in `CallbackManager.configure()`."""
|
||||
|
||||
def test_tracing_metadata_applied_via_configure(self) -> None:
|
||||
"""tracing_metadata flows through configure to the CallbackManager."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
tracing_metadata={"env": "prod", "service": "api"},
|
||||
)
|
||||
# Metadata is stored on the manager, not mutated on the shared tracer.
|
||||
assert cm.tracing_metadata == {"env": "prod", "service": "api"}
|
||||
# The shared tracer instance is NOT mutated.
|
||||
assert tracer.tracing_metadata is None
|
||||
|
||||
def test_tracing_metadata_does_not_overwrite_tracer_metadata(self) -> None:
|
||||
"""Tracer's own metadata takes precedence over tracing_metadata."""
|
||||
tracer = _create_tracer_with_mocked_client(metadata={"env": "staging"})
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
tracing_metadata={"env": "prod", "service": "api"},
|
||||
)
|
||||
# The shared tracer instance is NOT mutated.
|
||||
assert tracer.tracing_metadata == {"env": "staging"}
|
||||
# Per-invocation defaults are on the manager.
|
||||
assert cm.tracing_metadata == {"env": "prod", "service": "api"}
|
||||
|
||||
def test_tracing_metadata_end_to_end(self) -> None:
|
||||
"""tracing_metadata in configure propagates to posted runs."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
tracing_metadata={"env": "prod"},
|
||||
)
|
||||
my_func.invoke(1, {"callbacks": cm})
|
||||
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) == 1
|
||||
md = posts[0].get("extra", {}).get("metadata", {})
|
||||
assert md["env"] == "prod"
|
||||
|
||||
def test_tracing_metadata_does_not_affect_non_tracer_handlers(self) -> None:
|
||||
"""tracing_metadata only applies to tracer handlers, not other handlers."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
|
||||
received_metadata: list[dict[str, Any]] = []
|
||||
|
||||
class MetadataCapture(BaseCallbackHandler):
|
||||
def on_chain_start(self, *_args: Any, **kwargs: Any) -> None:
|
||||
received_metadata.append(dict(kwargs.get("metadata", {})))
|
||||
|
||||
capture = MetadataCapture()
|
||||
cm = CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer, capture],
|
||||
tracing_metadata={"tracer_only": "yes"},
|
||||
)
|
||||
|
||||
@RunnableLambda
|
||||
def my_func(x: int) -> int:
|
||||
return x
|
||||
|
||||
my_func.invoke(1, {"callbacks": cm})
|
||||
|
||||
# Non-tracer handler should NOT see tracing_metadata
|
||||
assert len(received_metadata) >= 1
|
||||
for md in received_metadata:
|
||||
assert "tracer_only" not in md
|
||||
|
||||
# But the tracer's posted runs SHOULD have it
|
||||
posts = _get_posts(tracer.client)
|
||||
assert len(posts) >= 1
|
||||
for post in posts:
|
||||
post_md = post.get("extra", {}).get("metadata", {})
|
||||
assert post_md["tracer_only"] == "yes"
|
||||
|
||||
def test_no_tracing_metadata_is_noop(self) -> None:
|
||||
"""Passing tracing_metadata=None does not alter tracer state."""
|
||||
tracer = _create_tracer_with_mocked_client()
|
||||
CallbackManager.configure(
|
||||
inheritable_callbacks=[tracer],
|
||||
tracing_metadata=None,
|
||||
)
|
||||
assert tracer.tracing_metadata is None
|
||||
|
||||
@@ -15,6 +15,7 @@ from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
from langchain_core.tracers.langchain import (
|
||||
LangChainTracer,
|
||||
_get_usage_metadata_from_generations,
|
||||
_patch_missing_metadata,
|
||||
)
|
||||
from langchain_core.tracers.schemas import Run
|
||||
|
||||
@@ -696,3 +697,148 @@ def test_on_chain_error_updates_when_not_defers_inputs() -> None:
|
||||
# Should call update (PATCH), not persist (POST) for normal inputs
|
||||
assert not persist_called
|
||||
assert update_called
|
||||
|
||||
|
||||
class TestPatchMissingMetadata:
|
||||
"""Tests for `_patch_missing_metadata` and tracer metadata behavior."""
|
||||
|
||||
@staticmethod
|
||||
def _make_tracer(
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
client.tracing_queue = None
|
||||
return LangChainTracer(client=client, metadata=metadata)
|
||||
|
||||
@staticmethod
|
||||
def _make_run(
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> Run:
|
||||
return Run(
|
||||
id=uuid.uuid4(),
|
||||
name="test",
|
||||
inputs={},
|
||||
run_type="chain",
|
||||
extra={"metadata": metadata or {}},
|
||||
)
|
||||
|
||||
def test_adds_metadata_when_run_has_none(self) -> None:
|
||||
"""Tracer metadata fills in when the run has no matching keys."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "service": "api"})
|
||||
run = self._make_run()
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata["env"] == "prod"
|
||||
assert run.metadata["service"] == "api"
|
||||
|
||||
def test_does_not_overwrite_existing_keys(self) -> None:
|
||||
"""Config metadata takes precedence over tracer metadata."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "service": "api"})
|
||||
run = self._make_run(metadata={"env": "staging"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata["env"] == "staging"
|
||||
assert run.metadata["service"] == "api"
|
||||
|
||||
def test_noop_when_tracer_has_no_metadata(self) -> None:
|
||||
"""No-op when the tracer has no metadata configured."""
|
||||
tracer = self._make_tracer(metadata=None)
|
||||
run = self._make_run(metadata={"existing": "value"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {"existing": "value"}
|
||||
|
||||
def test_noop_when_all_keys_already_present(self) -> None:
|
||||
"""No-op when every tracer key already exists in the run."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod"})
|
||||
run = self._make_run(metadata={"env": "dev"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {"env": "dev"}
|
||||
|
||||
def test_merges_disjoint_keys(self) -> None:
|
||||
"""Disjoint keys from tracer and config are all present after patching."""
|
||||
tracer = self._make_tracer(metadata={"tracer_key": "tracer_val"})
|
||||
run = self._make_run(metadata={"config_key": "config_val"})
|
||||
|
||||
_patch_missing_metadata(tracer, run)
|
||||
|
||||
assert run.metadata == {
|
||||
"tracer_key": "tracer_val",
|
||||
"config_key": "config_val",
|
||||
}
|
||||
|
||||
def test_persist_run_single_applies_tracer_metadata(self) -> None:
|
||||
"""End-to-end: `_persist_run_single` calls `_patch_missing_metadata`."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod"})
|
||||
run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
|
||||
tracer.on_chain_start(
|
||||
{"name": "test_chain"},
|
||||
{"input": "hello"},
|
||||
run_id=run_id,
|
||||
)
|
||||
run = tracer.run_map[str(run_id)]
|
||||
|
||||
with unittest.mock.patch.object(Run, "post"):
|
||||
tracer._persist_run_single(run)
|
||||
|
||||
assert run.metadata.get("env") == "prod"
|
||||
|
||||
def test_persist_run_single_config_metadata_wins(self) -> None:
|
||||
"""Config metadata is not overwritten by tracer metadata during persist."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod", "extra": "from_tracer"})
|
||||
run_id = UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160b")
|
||||
tracer.on_chain_start(
|
||||
{"name": "test_chain"},
|
||||
{"input": "hello"},
|
||||
run_id=run_id,
|
||||
metadata={"env": "staging"},
|
||||
)
|
||||
run = tracer.run_map[str(run_id)]
|
||||
|
||||
with unittest.mock.patch.object(Run, "post"):
|
||||
tracer._persist_run_single(run)
|
||||
|
||||
assert run.metadata["env"] == "staging"
|
||||
assert run.metadata["extra"] == "from_tracer"
|
||||
|
||||
|
||||
class TestSetDefaults:
|
||||
"""Tests for `LangChainTracer.set_defaults()`."""
|
||||
|
||||
@staticmethod
|
||||
def _make_tracer(
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> LangChainTracer:
|
||||
client = unittest.mock.MagicMock(spec=Client)
|
||||
client.tracing_queue = None
|
||||
return LangChainTracer(client=client, metadata=metadata)
|
||||
|
||||
def test_sets_metadata_when_none(self) -> None:
|
||||
"""Fills in metadata when tracer has no prior metadata."""
|
||||
tracer = self._make_tracer()
|
||||
tracer.set_defaults(metadata={"env": "prod"})
|
||||
assert tracer.tracing_metadata == {"env": "prod"}
|
||||
|
||||
def test_does_not_overwrite_existing_keys(self) -> None:
|
||||
"""Existing keys are preserved; only missing keys are added."""
|
||||
tracer = self._make_tracer(metadata={"env": "staging"})
|
||||
tracer.set_defaults(metadata={"env": "prod", "service": "api"})
|
||||
assert tracer.tracing_metadata == {"env": "staging", "service": "api"}
|
||||
|
||||
def test_noop_when_defaults_is_none(self) -> None:
|
||||
"""No-op when metadata=None is passed."""
|
||||
tracer = self._make_tracer(metadata={"env": "prod"})
|
||||
tracer.set_defaults(metadata=None)
|
||||
assert tracer.tracing_metadata == {"env": "prod"}
|
||||
|
||||
def test_multiple_calls_accumulate(self) -> None:
|
||||
"""Successive calls fill in disjoint keys."""
|
||||
tracer = self._make_tracer()
|
||||
tracer.set_defaults(metadata={"a": "1"})
|
||||
tracer.set_defaults(metadata={"b": "2", "a": "overwrite"})
|
||||
assert tracer.tracing_metadata == {"a": "1", "b": "2"}
|
||||
|
||||
Reference in New Issue
Block a user