cr
@@ -71,7 +71,12 @@ class LangSmithParams(TypedDict, total=False):
     """Stop words for generation."""

     versions: dict[str, str]
-    """Package versions for tracing (e.g., `{"langchain-anthropic": "1.3.3"}`)."""
+    """Package versions for tracing (e.g., `{"langchain-anthropic": "1.3.3"}`).
+
+    Maps partner package names to their installed versions. Deep-merged with
+    existing metadata so that versions from multiple integration layers are
+    preserved rather than overwritten.
+    """


 @cache  # Cache the tokenizer
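For illustration, the deep-merge contract the new docstring describes keeps version entries contributed by different layers side by side, where a flat dict spread would drop all but the last. A minimal standalone sketch (plain dicts; the package names and versions are made up):

    layer_a = {"versions": {"langchain-anthropic": "1.3.3"}}
    layer_b = {"versions": {"langchain-core": "0.3.0"}}

    # Flat spread: layer_b's "versions" dict replaces layer_a's entirely.
    flat = {**layer_a, **layer_b}
    assert flat["versions"] == {"langchain-core": "0.3.0"}

    # One-level deep merge: the inner mappings are combined.
    deep = {**layer_a, "versions": {**layer_a["versions"], **layer_b["versions"]}}
    assert deep["versions"] == {
        "langchain-anthropic": "1.3.3",
        "langchain-core": "0.3.0",
    }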
@@ -54,7 +54,10 @@ from langchain_core.messages import (
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
 from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
 from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list
-from langchain_core.runnables.config import run_in_executor
+from langchain_core.runnables.config import (
+    _merge_metadata_dicts,
+    run_in_executor,
+)

 if TYPE_CHECKING:
     import uuid
@@ -523,10 +526,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
         params = {**params, **kwargs}
         options = {"stop": stop}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = CallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,
@@ -593,10 +596,10 @@ class BaseLLM(BaseLanguageModel[str], ABC):
             params["stop"] = stop
         params = {**params, **kwargs}
         options = {"stop": stop}
-        inheritable_metadata = {
-            **(config.get("metadata") or {}),
-            **self._get_ls_params(stop=stop, **kwargs),
-        }
+        inheritable_metadata = _merge_metadata_dicts(
+            config.get("metadata") or {},
+            self._get_ls_params(stop=stop, **kwargs),
+        )
         callback_manager = AsyncCallbackManager.configure(
             config.get("callbacks"),
             self.callbacks,

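Both the sync and async call sites now build `inheritable_metadata` the same way. A quick check of the behavioral difference, assuming this patch is applied so that `_merge_metadata_dicts` is importable (the metadata values here are hypothetical):

    from langchain_core.runnables.config import _merge_metadata_dicts

    config_metadata = {"ls_provider": "user-set", "versions": {"my-app": "2.0"}}
    ls_params = {"ls_provider": "fake", "versions": {"langchain-fake": "0.1.0"}}

    merged = _merge_metadata_dicts(config_metadata, ls_params)
    # Top level is still last-writer-wins, exactly like the old spread:
    assert merged["ls_provider"] == "fake"
    # But same-key mappings are now combined instead of clobbered:
    assert merged["versions"] == {"my-app": "2.0", "langchain-fake": "0.1.0"}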
@@ -367,23 +367,40 @@ def _merge_metadata_dicts(
 ) -> dict[str, Any]:
     """Merge two metadata dicts with one extra level of depth.

-    If both sides have a dict value for the same key, the inner dicts are merged
-    (last-writer-wins within). Non-dict values use last-writer-wins at the
-    top level.
+    If both sides have a `Mapping` value for the same key, the inner mappings
+    are merged (last-writer-wins within). Non-mapping values use
+    last-writer-wins at the top level. Only one level of depth is merged; values
+    nested more deeply are not recursively merged.

     Args:
         base: The base metadata dict.
-        incoming: The incoming metadata dict to merge on top.
+
+            Values here are kept unless overridden by `incoming`.
+        incoming: The metadata dict to merge on top.
+
+            Its values take precedence on conflict.

     Returns:
-        A new merged dict. Does not mutate inputs.
+        A new merged dict.
+
+        Inputs are not mutated. The returned dict performs shallow copies at
+        the top level and one level deep; mutable values nested beyond that
+        depth are shared references with the originals.
     """
     merged = {**base}
     for key, value in incoming.items():
-        if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
+        if (
+            key in merged
+            and isinstance(merged[key], Mapping)
+            and isinstance(value, Mapping)
+        ):
             merged[key] = {**merged[key], **value}
         else:
             merged[key] = value
+    # Ensure non-overlapping nested mappings are also copies, not shared refs.
+    for key in base:
+        if key not in incoming and isinstance(merged[key], Mapping):
+            merged[key] = {**merged[key]}
     return merged

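A quick tour of the semantics spelled out in that docstring, using a local copy of the helper above so the snippet runs without the patch installed:

    from collections.abc import Mapping
    from types import MappingProxyType
    from typing import Any


    def merge(base: dict[str, Any], incoming: dict[str, Any]) -> dict[str, Any]:
        # Local copy of _merge_metadata_dicts as written above.
        merged = {**base}
        for key, value in incoming.items():
            if (
                key in merged
                and isinstance(merged[key], Mapping)
                and isinstance(value, Mapping)
            ):
                merged[key] = {**merged[key], **value}
            else:
                merged[key] = value
        for key in base:
            if key not in incoming and isinstance(merged[key], Mapping):
                merged[key] = {**merged[key]}
        return merged


    base = {"versions": {"a": "1"}, "tags": ["x"], "deep": {"inner": {"n": 1}}}
    out = merge(base, {"versions": {"b": "2"}, "tags": ["y"]})

    assert out["versions"] == {"a": "1", "b": "2"}        # one level deep: merged
    assert out["tags"] == ["y"]                           # non-mapping: last writer wins
    assert out["deep"] is not base["deep"]                # top-level mapping: copied
    assert out["deep"]["inner"] is base["deep"]["inner"]  # beyond one level: shared

    # The switch from dict to Mapping also covers read-only mappings:
    ro = MappingProxyType({"c": "3"})
    assert merge({"versions": ro}, {"versions": {"d": "4"}})["versions"] == {
        "c": "3",
        "d": "4",
    }

Note the last assertion is exactly what the old `isinstance(..., dict)` check missed: a `MappingProxyType` value would have been replaced wholesale rather than merged.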
@@ -17,6 +17,7 @@ from langchain_core.language_models import (
     ParrotFakeChatModel,
 )
 from langchain_core.language_models._utils import _normalize_messages
+from langchain_core.language_models.base import LangSmithParams
 from langchain_core.language_models.chat_models import _generate_response_from_error
 from langchain_core.language_models.fake_chat_models import (
     FakeListChatModelError,
@@ -45,6 +46,7 @@ from tests.unit_tests.stubs import _any_id_ai_message, _any_id_ai_message_chunk

 if TYPE_CHECKING:
     from langchain_core.outputs.llm_result import LLMResult
+    from langchain_core.runnables.config import RunnableConfig


 def _content_blocks_equal_ignore_id(
@@ -1213,6 +1215,53 @@ def test_get_ls_params() -> None:
     assert ls_params["ls_stop"] == ["stop"]


+class _VersionedFakeModel(FakeListChatModel):
+    """Fake model that reports a versions dict in `ls_params`."""
+
+    def _get_ls_params(
+        self, stop: list[str] | None = None, **kwargs: Any
+    ) -> LangSmithParams:
+        params = super()._get_ls_params(stop=stop, **kwargs)
+        params["versions"] = {"langchain-fake": "0.1.0"}
+        return params
+
+
+def test_user_versions_metadata_survives_merge() -> None:
+    """User-provided versions metadata should be deep-merged with model versions.
+
+    Regression test: if the merge in `BaseChatModel` reverts to a flat dict
+    spread, user-provided versions would be silently overwritten by the
+    model's versions.
+    """
+    llm = _VersionedFakeModel(responses=["hello"])
+    user_config: RunnableConfig = {"metadata": {"versions": {"my-app": "2.0"}}}
+
+    with collect_runs() as cb:
+        llm.invoke([HumanMessage(content="hi")], config=user_config)
+    assert len(cb.traced_runs) == 1
+    run_metadata = cb.traced_runs[0].extra["metadata"]
+    # Both user-provided and model-provided versions must be present.
+    assert run_metadata["versions"] == {
+        "my-app": "2.0",
+        "langchain-fake": "0.1.0",
+    }
+
+
+async def test_user_versions_metadata_survives_merge_async() -> None:
+    """Async variant: user-provided versions metadata deep-merged with model's."""
+    llm = _VersionedFakeModel(responses=["hello"])
+    user_config: RunnableConfig = {"metadata": {"versions": {"my-app": "2.0"}}}
+
+    with collect_runs() as cb:
+        await llm.ainvoke([HumanMessage(content="hi")], config=user_config)
+    assert len(cb.traced_runs) == 1
+    run_metadata = cb.traced_runs[0].extra["metadata"]
+    assert run_metadata["versions"] == {
+        "my-app": "2.0",
+        "langchain-fake": "0.1.0",
+    }
+
+
 def test_model_profiles() -> None:
     model = GenericFakeChatModel(messages=iter([]))
     assert model.profile is None
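For reference, the user-facing pattern these tests lock in is plain config metadata. A hedged sketch with a stock fake model (any `BaseChatModel` subclass behaves the same; the app name and version are placeholders):

    from langchain_core.language_models.fake_chat_models import FakeListChatModel

    llm = FakeListChatModel(responses=["hello"])
    llm.invoke(
        "hi",
        config={"metadata": {"versions": {"my-app": "2.0"}}},
    )
    # With this change, the traced run's metadata keeps the caller's
    # {"my-app": "2.0"} entry alongside any versions the model reports,
    # instead of having it silently overwritten.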
@@ -203,9 +203,38 @@ class TestMergeMetadataDicts:
         incoming = {"versions": {"b": "2"}}
         base_copy = {"versions": {"a": "1"}}
         incoming_copy = {"versions": {"b": "2"}}
-        _merge_metadata_dicts(base, incoming)
+        result = _merge_metadata_dicts(base, incoming)
         assert base == base_copy
         assert incoming == incoming_copy
+        # Returned nested dicts should not share identity with originals.
+        assert result["versions"] is not base["versions"]
+        assert result["versions"] is not incoming["versions"]
+
+    def test_non_overlapping_nested_dict_is_copied(self) -> None:
+        base = {"versions": {"a": "1"}, "extras": {"x": "y"}}
+        incoming = {"versions": {"b": "2"}}
+        result = _merge_metadata_dicts(base, incoming)
+        # "extras" was not in incoming — result should still be a copy.
+        assert result["extras"] is not base["extras"]
+        assert result["extras"] == {"x": "y"}
+
+    def test_both_empty(self) -> None:
+        assert _merge_metadata_dicts({}, {}) == {}
+
+    def test_empty_base(self) -> None:
+        result = _merge_metadata_dicts({}, {"versions": {"pkg": "1.0"}})
+        assert result == {"versions": {"pkg": "1.0"}}
+
+    def test_empty_incoming(self) -> None:
+        result = _merge_metadata_dicts({"versions": {"pkg": "1.0"}}, {})
+        assert result == {"versions": {"pkg": "1.0"}}
+
+    def test_merge_configs_with_none_metadata(self) -> None:
+        merged = merge_configs(
+            cast("RunnableConfig", {"metadata": None}),
+            {"metadata": {"versions": {"a": "1"}}},
+        )
+        assert merged["metadata"] == {"versions": {"a": "1"}}

     def test_three_config_merge_accumulates(self) -> None:
         c1: RunnableConfig = {"metadata": {"versions": {"a": "1"}}}
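The hunk is truncated after `c1`, and `test_three_config_merge_accumulates` presumably continues in the same vein. A sketch of the behavior its name targets, assuming `merge_configs` routes metadata through `_merge_metadata_dicts` as the preceding test implies (the keys and version values are made up):

    from langchain_core.runnables.config import merge_configs

    c1 = {"metadata": {"versions": {"a": "1"}}}
    c2 = {"metadata": {"versions": {"b": "2"}}}
    c3 = {"metadata": {"versions": {"c": "3"}}}

    merged = merge_configs(c1, c2, c3)
    # With the deep merge in place, all three entries accumulate:
    assert merged["metadata"] == {"versions": {"a": "1", "b": "2", "c": "3"}}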
@@ -1685,6 +1685,14 @@ def test_anthropic_model_params() -> None:
     assert ls_params.get("ls_model_name") == MODEL_NAME


+def test_ls_params_versions_value() -> None:
+    """Test that _get_ls_params reports the correct langchain-anthropic version."""
+    llm = ChatAnthropic(model=MODEL_NAME)
+    ls_params = llm._get_ls_params()
+    assert "versions" in ls_params
+    assert ls_params["versions"] == {"langchain-anthropic": __version__}
+
+
 def test_streaming_cache_token_reporting() -> None:
     """Test that cache tokens are properly reported in streaming events."""
     from unittest.mock import MagicMock