@@ -70,6 +70,9 @@ class LangSmithParams(TypedDict, total=False):
     ls_stop: list[str] | None
     """Stop words for generation."""
 
+    versions: dict[str, str]
+    """Package versions for tracing (e.g., `{"langchain-anthropic": "1.3.3"}`)."""
+
 
 @cache  # Cache the tokenizer
 def get_tokenizer() -> Any:

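Note on the hunk above: because `LangSmithParams` is declared with `total=False`, the new `versions` key is optional like every other field, so existing integrations that never set it still type-check. A minimal self-contained sketch of that behavior (the class below is an abbreviated stand-in for illustration, not the real `LangSmithParams`):

from typing import TypedDict


class LangSmithParamsSketch(TypedDict, total=False):
    """Abbreviated stand-in for LangSmithParams (illustration only)."""

    ls_stop: list[str] | None
    versions: dict[str, str]


# total=False makes every key optional, so both of these type-check:
without_versions: LangSmithParamsSketch = {"ls_stop": ["\n"]}
with_versions: LangSmithParamsSketch = {
    "versions": {"langchain-anthropic": "1.3.3"},
}
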
@@ -63,7 +63,11 @@ from langchain_core.outputs.chat_generation import merge_chat_generation_chunks
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.rate_limiters import BaseRateLimiter
 from langchain_core.runnables import RunnableMap, RunnablePassthrough
-from langchain_core.runnables.config import ensure_config, run_in_executor
+from langchain_core.runnables.config import (
+    _merge_metadata_dicts,
+    ensure_config,
+    run_in_executor,
+)
 from langchain_core.tracers._streaming import _StreamingCallbackHandler
 from langchain_core.utils.function_calling import (
     convert_to_json_schema,

@@ -503,10 +507,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
 
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = CallbackManager.configure(
            config.get("callbacks"),
            self.callbacks,

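This hunk and the three matching ones below (at 631, 895, and 1018) all make the same fix: the old shallow `{**a, **b}` spread replaced a dict-valued key such as `versions` wholesale, so a version stamp in the caller's config metadata and one in the model's LangSmith params could not coexist. A runnable sketch of the failure mode (the dict contents are illustrative, not taken from the codebase):

# Shallow spread: the incoming "versions" dict clobbers the caller's.
caller_metadata = {"versions": {"my-app": "0.9.0"}, "user_id": "abc"}
ls_params = {"ls_model_type": "chat", "versions": {"langchain-anthropic": "1.3.3"}}

shallow = {**caller_metadata, **ls_params}
assert shallow["versions"] == {"langchain-anthropic": "1.3.3"}  # "my-app" is gone

# _merge_metadata_dicts (added in this commit) merges dict values one
# level deep instead, yielding:
#   {"my-app": "0.9.0", "langchain-anthropic": "1.3.3"}
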
@@ -631,10 +635,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
 
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **kwargs, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = AsyncCallbackManager.configure(
            config.get("callbacks"),
            self.callbacks,

@@ -895,10 +899,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
 
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(metadata or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            metadata or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
 
         callback_manager = CallbackManager.configure(
             callbacks,

@@ -1018,10 +1022,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC):
 
         params = self._get_invocation_params(stop=stop, **kwargs)
         options = {"stop": stop, **ls_structured_output_format_dict}
-        inheritable_metadata = {
-            **(metadata or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            metadata or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
 
         callback_manager = AsyncCallbackManager.configure(
             callbacks,

@@ -7,7 +7,15 @@ import asyncio
 # Cannot move uuid to TYPE_CHECKING as RunnableConfig is used in Pydantic models
 import uuid  # noqa: TC003
 import warnings
-from collections.abc import Awaitable, Callable, Generator, Iterable, Iterator, Sequence
+from collections.abc import (
+    Awaitable,
+    Callable,
+    Generator,
+    Iterable,
+    Iterator,
+    Mapping,
+    Sequence,
+)
 from concurrent.futures import Executor, Future, ThreadPoolExecutor
 from contextlib import contextmanager
 from contextvars import Context, ContextVar, Token, copy_context

@@ -354,6 +362,31 @@ def patch_config(
     return config
 
 
+def _merge_metadata_dicts(
+    base: Mapping[str, Any], incoming: Mapping[str, Any]
+) -> dict[str, Any]:
+    """Merge two metadata dicts with one extra level of depth.
+
+    If both sides have a dict value for the same key, the inner dicts are merged
+    (last-writer-wins within). Non-dict values use last-writer-wins at the
+    top level.
+
+    Args:
+        base: The base metadata dict.
+        incoming: The incoming metadata dict to merge on top.
+
+    Returns:
+        A new merged dict. Does not mutate inputs.
+    """
+    merged = {**base}
+    for key, value in incoming.items():
+        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+            merged[key] = {**merged[key], **value}
+        else:
+            merged[key] = value
+    return merged
+
+
 def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
     """Merge multiple configs into one.
 

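A quick usage sketch of the helper added above. The function body is copied verbatim so the example runs standalone, without the patched package installed. It also highlights the deliberate design choice visible in the code: the merge goes exactly one level deep, so dicts nested two levels down are still replaced wholesale rather than merged.

from collections.abc import Mapping
from typing import Any


def _merge_metadata_dicts(
    base: Mapping[str, Any], incoming: Mapping[str, Any]
) -> dict[str, Any]:
    # Verbatim copy of the helper added in this commit.
    merged = {**base}
    for key, value in incoming.items():
        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = {**merged[key], **value}
        else:
            merged[key] = value
    return merged


# Dict-valued keys merge one level deep...
assert _merge_metadata_dicts(
    {"versions": {"a": "1"}}, {"versions": {"b": "2"}}
) == {"versions": {"a": "1", "b": "2"}}

# ...but the merge is not recursive: at depth two, last writer still wins.
assert _merge_metadata_dicts(
    {"outer": {"inner": {"a": "1"}}}, {"outer": {"inner": {"b": "2"}}}
) == {"outer": {"inner": {"b": "2"}}}
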
@@ -369,10 +402,10 @@ def merge_configs(*configs: RunnableConfig | None) -> RunnableConfig:
     for config in (ensure_config(c) for c in configs if c is not None):
         for key in config:
             if key == "metadata":
-                base["metadata"] = {
-                    **base.get("metadata", {}),
-                    **(config.get("metadata") or {}),
-                }
+                base["metadata"] = _merge_metadata_dicts(
+                    base.get("metadata", {}),
+                    config.get("metadata") or {},
+                )
             elif key == "tags":
                 base["tags"] = sorted(
                     set(base.get("tags", []) + (config.get("tags") or [])),

@@ -16,6 +16,7 @@ from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
 from langchain_core.runnables import RunnableBinding, RunnablePassthrough
 from langchain_core.runnables.config import (
     RunnableConfig,
+    _merge_metadata_dicts,
     _set_config_context,
     ensure_config,
     merge_configs,

@@ -161,3 +162,56 @@ async def test_run_in_executor() -> None:
 
     with pytest.raises(RuntimeError):
         await run_in_executor(None, raises_stop_iter)
+
+
+class TestMergeMetadataDicts:
+    """Tests for _merge_metadata_dicts deep-merge behavior."""
+
+    def test_deep_merge_preserves_both_nested_dicts(self) -> None:
+        base = {"versions": {"langchain-core": "0.3.1"}, "user_id": "abc"}
+        incoming = {"versions": {"langchain-anthropic": "1.3.3"}, "run": "x"}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {
+            "versions": {
+                "langchain-core": "0.3.1",
+                "langchain-anthropic": "1.3.3",
+            },
+            "user_id": "abc",
+            "run": "x",
+        }
+
+    def test_last_writer_wins_within_nested_dicts(self) -> None:
+        base = {"versions": {"pkg": "1.0"}}
+        incoming = {"versions": {"pkg": "2.0"}}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"versions": {"pkg": "2.0"}}
+
+    def test_non_dict_overwrites_dict(self) -> None:
+        base = {"key": {"nested": "value"}}
+        incoming = {"key": "flat"}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"key": "flat"}
+
+    def test_dict_overwrites_non_dict(self) -> None:
+        base = {"key": "flat"}
+        incoming = {"key": {"nested": "value"}}
+        result = _merge_metadata_dicts(base, incoming)
+        assert result == {"key": {"nested": "value"}}
+
+    def test_no_mutation_of_inputs(self) -> None:
+        base = {"versions": {"a": "1"}}
+        incoming = {"versions": {"b": "2"}}
+        base_copy = {"versions": {"a": "1"}}
+        incoming_copy = {"versions": {"b": "2"}}
+        _merge_metadata_dicts(base, incoming)
+        assert base == base_copy
+        assert incoming == incoming_copy
+
+    def test_three_config_merge_accumulates(self) -> None:
+        c1: RunnableConfig = {"metadata": {"versions": {"a": "1"}}}
+        c2: RunnableConfig = {"metadata": {"versions": {"b": "2"}}}
+        c3: RunnableConfig = {"metadata": {"versions": {"c": "3"}}}
+        merged = merge_configs(c1, c2, c3)
+        assert merged["metadata"] == {
+            "versions": {"a": "1", "b": "2", "c": "3"},
+        }

@@ -998,6 +998,7 @@ class ChatAnthropic(BaseChatModel):
             ls_params["ls_max_tokens"] = ls_max_tokens
         if ls_stop := stop or params.get("stop", None):
             ls_params["ls_stop"] = ls_stop
+        ls_params["versions"] = {"langchain-anthropic": __version__}
         return ls_params
 
     @model_validator(mode="before")

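With the added line, `ChatAnthropic._get_ls_params` stamps the package version unconditionally on every call, so every traced run carries it. A sketch of the resulting shape (all values other than the `versions` entry are illustrative, including the hypothetical model name):

# Illustrative ls_params after this change; real values depend on the
# model configuration and the installed langchain-anthropic version.
ls_params = {
    "ls_provider": "anthropic",
    "ls_model_type": "chat",
    "ls_model_name": "claude-sonnet-4-5",  # hypothetical model name
    "ls_max_tokens": 64000,
    "versions": {"langchain-anthropic": "1.3.3"},  # added by this commit
}
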
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, SecretStr, ValidationError
 from pytest import CaptureFixture, MonkeyPatch
 
 from langchain_anthropic import ChatAnthropic
+from langchain_anthropic._version import __version__
 from langchain_anthropic.chat_models import (
     _create_usage_metadata,
     _format_image,

@@ -1677,12 +1678,21 @@ def test_anthropic_model_params() -> None:
         "ls_model_name": MODEL_NAME,
         "ls_max_tokens": 64000,
         "ls_temperature": None,
+        "versions": {"langchain-anthropic": __version__},
     }
 
     ls_params = llm._get_ls_params(model=MODEL_NAME)
     assert ls_params.get("ls_model_name") == MODEL_NAME
+
+
+def test_ls_params_versions_value() -> None:
+    """Test that _get_ls_params reports the correct langchain-anthropic version."""
+    llm = ChatAnthropic(model=MODEL_NAME)
+    ls_params = llm._get_ls_params()
+    assert "versions" in ls_params
+    assert ls_params["versions"] == {"langchain-anthropic": __version__}
 
 
 def test_streaming_cache_token_reporting() -> None:
     """Test that cache tokens are properly reported in streaming events."""
     from unittest.mock import MagicMock