diff --git a/libs/core/langchain_core/tracers/langchain.py b/libs/core/langchain_core/tracers/langchain.py index 2a2b5ce0f62..b5563f94ae6 100644 --- a/libs/core/langchain_core/tracers/langchain.py +++ b/libs/core/langchain_core/tracers/langchain.py @@ -27,6 +27,8 @@ from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.schemas import Run if TYPE_CHECKING: + from collections.abc import Mapping + from langchain_core.messages import BaseMessage from langchain_core.outputs import ChatGenerationChunk, GenerationChunk @@ -124,6 +126,8 @@ class LangChainTracer(BaseTracer): project_name: str | None = None, client: Client | None = None, tags: list[str] | None = None, + *, + metadata: Mapping[str, str] | None = None, **kwargs: Any, ) -> None: """Initialize the LangChain tracer. @@ -139,6 +143,9 @@ class LangChainTracer(BaseTracer): tags: The tags. Defaults to an empty list. + metadata: Additional metadata to include if it isn't already in the run. + + Defaults to None. **kwargs: Additional keyword arguments. """ super().__init__(**kwargs) @@ -150,6 +157,7 @@ class LangChainTracer(BaseTracer): self.tags = tags or [] self.latest_run: Run | None = None self.run_has_token_event_map: dict[str, bool] = {} + self.metadata = metadata def _start_trace(self, run: Run) -> None: if self.project_name: @@ -263,6 +271,7 @@ class LangChainTracer(BaseTracer): try: run.extra["runtime"] = get_runtime_environment() run.tags = self._get_tags(run) + _patch_missing_metadata(self, run) if run.ls_client is not self.client: run.ls_client = self.client run.post() @@ -398,3 +407,12 @@ class LangChainTracer(BaseTracer): """Wait for the given futures to complete.""" if self.client is not None: self.client.flush() + + +def _patch_missing_metadata(self: LangChainTracer, run: Run) -> None: + if not self.metadata: + return + metadata = run.metadata + for k, v in self.metadata.items(): + if k not in metadata: + metadata[k] = v