Mason Daugherty
2026-02-17 21:49:55 -05:00
parent b004103721
commit ff2ecd0b7e
6 changed files with 127 additions and 22 deletions

View File

@@ -70,6 +70,9 @@ class LangSmithParams(TypedDict, total=False):
     ls_stop: list[str] | None
     """Stop words for generation."""
+    versions: dict[str, str]
+    """Package versions for tracing (e.g., `{"langchain-anthropic": "1.3.3"}`)."""
 
 
 @cache  # Cache the tokenizer
 def get_tokenizer() -> Any:
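
The new `versions` key is optional (`total=False`), so existing integrations that never set it still type-check. As a quick illustration, here is what a populated params dict could look like; the import path and version string are assumptions for illustration, not part of this commit:

# Sketch only: illustrative values; LangSmithParams import path assumed.
from langchain_core.language_models.chat_models import LangSmithParams

params: LangSmithParams = {
    "ls_provider": "anthropic",
    "ls_model_type": "chat",
    "versions": {"langchain-anthropic": "1.3.3"},
}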

View File

@@ -63,7 +63,11 @@ from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
-from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.runnables.config import (
+    _merge_metadata_dicts,
+    ensure_config,
+    run_in_executor,
+)
 from langchain_core.tracers._streaming import _StreamingCallbackHandler
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,
@@ -503,10 +507,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
@@ -631,10 +635,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
@@ -895,10 +899,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(metadata or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            metadata or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = CallbackManager.configure(
             callbacks,
@@ -1018,10 +1022,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(metadata or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            metadata or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = AsyncCallbackManager.configure(
             callbacks,
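
All four call sites make the same swap. The old `{**a, **b}` spread was shallow: when both `config["metadata"]` and `_get_ls_params()` carried a `versions` dict, the LangSmith one replaced the caller's wholesale. A small before/after sketch with illustrative version strings:

from langchain_core.runnables.config import _merge_metadata_dicts

config_meta = {"versions": {"langchain-core": "0.3.1"}}
ls_meta = {"versions": {"langchain-anthropic": "1.3.3"}}

# Shallow spread: ls_meta's inner dict wins outright, dropping langchain-core.
assert {**config_meta, **ls_meta}["versions"] == {"langchain-anthropic": "1.3.3"}

# One-level-deep merge: both inner entries survive.
assert _merge_metadata_dicts(config_meta, ls_meta)["versions"] == {
    "langchain-core": "0.3.1",
    "langchain-anthropic": "1.3.3",
}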

View File

@@ -7,7 +7,15 @@ import asyncio
 # Cannot move uuid to TYPE_CHECKING as RunnableConfig is used in Pydantic models
 import uuid  # noqa: TC003
 import warnings
-from collections.abc import Awaitable, Callable, Generator, Iterable, Iterator, Sequence
+from collections.abc import (
+    Awaitable,
+    Callable,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+)
 from concurrent.futures import Executor, Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from contextvars import Context, ContextVar, Token, copy_context
@@ -354,6 +362,31 @@ def patch_config(
     return config
 
 
+def _merge_metadata_dicts(
+    base: Mapping[str, Any], incoming: Mapping[str, Any]
+) -> dict[str, Any]:
+    """Merge two metadata dicts with one extra level of depth.
+
+    If both sides have a dict value for the same key, the inner dicts are merged
+    (last-writer-wins within). Non-dict values use last-writer-wins at the
+    top level.
+
+    Args:
+        base: The base metadata dict.
+        incoming: The incoming metadata dict to merge on top.
+
+    Returns:
+        A new merged dict. Does not mutate inputs.
+    """
+    merged = {**base}
+    for key, value in incoming.items():
+        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+            merged[key] = {**merged[key], **value}
+        else:
+            merged[key] = value
+    return merged
+
+
 def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
     """Merge multiple configs into one.
@@ -369,10 +402,10 @@ def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
     for config in (ensure_config(c) for c in configs if c is not None):
         for key in config:
             if key == "metadata":
-                base["metadata"] = {
-                    **base.get("metadata", {}),
-                    **(config.get("metadata") or {}),
-                }
+                base["metadata"] = _merge_metadata_dicts(
+                    base.get("metadata", {}),
+                    config.get("metadata") or {},
+                )
             elif key == "tags":
                 base["tags"] = sorted(
                     set(base.get("tags", []) + (config.get("tags") or [])),
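
Since `merge_configs` now routes metadata through the same helper, `versions` entries accumulate across any number of configs rather than last-config-wins (the `test_three_config_merge_accumulates` test below exercises exactly this). An illustrative sketch:

from langchain_core.runnables.config import merge_configs

merged = merge_configs(
    {"metadata": {"versions": {"langchain-core": "0.3.1"}}},
    {"metadata": {"versions": {"langchain-anthropic": "1.3.3"}}},
)
assert merged["metadata"]["versions"] == {
    "langchain-core": "0.3.1",
    "langchain-anthropic": "1.3.3",
}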

View File

@@ -16,6 +16,7 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_core.runnables import RunnableBinding, RunnablePassthrough
 from langchain_core.runnables.config import (
     RunnableConfig,
+    _merge_metadata_dicts,
     _set_config_context,
     ensure_config,
     merge_configs,
@@ -161,3 +162,56 @@ async def test_run_in_executor() -> None:
     with pytest.raises(RuntimeError):
         await run_in_executor(None, raises_stop_iter)
+
+
+class TestMergeMetadataDicts:
+    """Tests for _merge_metadata_dicts deep-merge behavior."""
+
+    def test_deep_merge_preserves_both_nested_dicts(self) -> None:
+        base = {"versions": {"langchain-core": "0.3.1"}, "user_id": "abc"}
+        incoming = {"versions": {"langchain-anthropic": "1.3.3"}, "run": "x"}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {
+            "versions": {
+                "langchain-core": "0.3.1",
+                "langchain-anthropic": "1.3.3",
+            },
+            "user_id": "abc",
+            "run": "x",
+        }
+
+    def test_last_writer_wins_within_nested_dicts(self) -> None:
+        base = {"versions": {"pkg": "1.0"}}
+        incoming = {"versions": {"pkg": "2.0"}}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"versions": {"pkg": "2.0"}}
+
+    def test_non_dict_overwrites_dict(self) -> None:
+        base = {"key": {"nested": "value"}}
+        incoming = {"key": "flat"}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"key": "flat"}
+
+    def test_dict_overwrites_non_dict(self) -> None:
+        base = {"key": "flat"}
+        incoming = {"key": {"nested": "value"}}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"key": {"nested": "value"}}
+
+    def test_no_mutation_of_inputs(self) -> None:
+        base = {"versions": {"a": "1"}}
+        incoming = {"versions": {"b": "2"}}
+        base_copy = {"versions": {"a": "1"}}
+        incoming_copy = {"versions": {"b": "2"}}
+        _merge_metadata_dicts(base, incoming)
+        assert base == base_copy
+        assert incoming == incoming_copy
+
+    def test_three_config_merge_accumulates(self) -> None:
+        c1: RunnableConfig = {"metadata": {"versions": {"a": "1"}}}
+        c2: RunnableConfig = {"metadata": {"versions": {"b": "2"}}}
+        c3: RunnableConfig = {"metadata": {"versions": {"c": "3"}}}
+        merged = merge_configs(c1, c2, c3)
+        assert merged["metadata"] == {
+            "versions": {"a": "1", "b": "2", "c": "3"},
+        }

View File

@@ -998,6 +998,7 @@ class ChatAnthropic(BaseChatModel):
         ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
+        ls_params["versions"] = {"langchain-anthropic": __version__}
         return ls_params
 
     @model_validator(mode="before")
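
With this one-liner, every `ChatAnthropic` run reports the installed package version in its LangSmith params. A quick way to see it, assuming `ANTHROPIC_API_KEY` is set in the environment; the model name is illustrative:

from langchain_anthropic import ChatAnthropic

llm = ChatAnthropic(model="claude-sonnet-4-5")  # model name illustrative
print(llm._get_ls_params()["versions"])  # e.g. {'langchain-anthropic': '1.3.3'}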

View File

@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, SecretStr, ValidationError
 from pytest import CaptureFixture, MonkeyPatch
 
 from langchain_anthropic import ChatAnthropic
+from langchain_anthropic._version import __version__
 from langchain_anthropic.chat_models import (
     _create_usage_metadata,
     _format_image,
@@ -1677,12 +1678,21 @@ def test_anthropic_model_params() -> None:
         "ls_model_name": MODEL_NAME,
         "ls_max_tokens": 64000,
         "ls_temperature": None,
+        "versions": {"langchain-anthropic": __version__},
     }
 
     ls_params = llm._get_ls_params(model=MODEL_NAME)
     assert ls_params.get("ls_model_name") == MODEL_NAME
 
 
+def test_ls_params_versions_value() -> None:
+    """Test that _get_ls_params reports the correct langchain-anthropic version."""
+    llm = ChatAnthropic(model=MODEL_NAME)
+    ls_params = llm._get_ls_params()
+    assert "versions" in ls_params
+    assert ls_params["versions"] == {"langchain-anthropic": __version__}
+
+
 def test_streaming_cache_token_reporting() -> None:
     """Test that cache tokens are properly reported in streaming events."""
     from unittest.mock import MagicMock