Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-08 14:31:55 +00:00)
libs/core/langchain_core/language_models/chat_models.py

@@ -15,6 +15,7 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Union,
     cast,
 )

@@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             for chunk in self._stream(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
+                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
                 if generation is None:
                     generation = chunk
@@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             async for chunk in _stream_implementation(
                 messages, stop=stop, run_manager=run_manager, **kwargs
             ):
+                chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                 yield chunk.message
                 if generation is None:
                     generation = chunk
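Both streaming hunks stamp merged provider metadata onto each chunk before it is yielded. A minimal, self-contained sketch of that per-chunk step (FakeMessage and FakeChunk are illustrative stand-ins, not langchain classes):

# Sketch only: the dataclasses below stand in for AIMessageChunk and
# ChatGenerationChunk; the merge mirrors the _gen_info_and_msg_metadata
# helper added at the end of this file's diff.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class FakeMessage:
    content: str
    response_metadata: dict = field(default_factory=dict)


@dataclass
class FakeChunk:
    message: FakeMessage
    generation_info: Optional[dict] = None


def gen_info_and_msg_metadata(chunk: FakeChunk) -> dict:
    # generation_info is unpacked first, message metadata second.
    return {**(chunk.generation_info or {}), **chunk.message.response_metadata}


chunk = FakeChunk(FakeMessage("Hello"), generation_info={"finish_reason": "stop"})
chunk.message.response_metadata = gen_info_and_msg_metadata(chunk)
print(chunk.message.response_metadata)  # {'finish_reason': 'stop'}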
@@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._generate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = llm_cache.lookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if new_arg_supported:
-                return self._generate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return self._generate(messages, stop=stop, **kwargs)
+        if inspect.signature(self._generate).parameters.get("run_manager"):
+            result = self._generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
         else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = llm_cache.lookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = self._generate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = self._generate(messages, stop=stop, **kwargs)
-                llm_cache.update(prompt, llm_string, result.generations)
-                return result
+            result = self._generate(messages, stop=stop, **kwargs)
+
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            llm_cache.update(prompt, llm_string, result.generations)
+        return result

     async def _agenerate_with_cache(
         self,
@@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = await llm_cache.alookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if new_arg_supported:
-                return await self._agenerate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return await self._agenerate(messages, stop=stop, **kwargs)
+        if inspect.signature(self._agenerate).parameters.get("run_manager"):
+            result = await self._agenerate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
         else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = await llm_cache.alookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = await self._agenerate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = await self._agenerate(messages, stop=stop, **kwargs)
-                await llm_cache.aupdate(prompt, llm_string, result.generations)
-                return result
+            result = await self._agenerate(messages, stop=stop, **kwargs)
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            await llm_cache.aupdate(prompt, llm_string, result.generations)
+        return result

     @abstractmethod
     def _generate(
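In both methods the rewritten gate boils down to: cache=True or cache=None consults a configured cache, cache=False opts out, and cache=True with no cache configured raises. A standalone sketch of just that decision (should_use_cache is a hypothetical helper, not part of the diff):

# Illustrative reduction of the new cache gate; not the langchain API itself.
from typing import Optional


def should_use_cache(cache: Optional[bool], llm_cache_configured: bool) -> bool:
    check_cache = cache or cache is None
    if check_cache and not llm_cache_configured:
        if cache is None:
            # No global cache and no explicit opt-in: silently skip caching.
            return False
        # cache=True but nothing configured: the real method raises here.
        raise ValueError("Asked to cache, but no cache found at `langchain.cache`.")
    return bool(check_cache and llm_cache_configured)


assert should_use_cache(None, llm_cache_configured=True) is True
assert should_use_cache(False, llm_cache_configured=True) is False
assert should_use_cache(None, llm_cache_configured=False) is False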
@@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel):
             run_manager=run_manager.get_sync() if run_manager else None,
             **kwargs,
         )
+
+
+def _gen_info_and_msg_metadata(
+    generation: Union[ChatGeneration, ChatGenerationChunk],
+) -> dict:
+    return {
+        **(generation.generation_info or {}),
+        **generation.message.response_metadata,
+    }
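Note the precedence in the helper above: because message.response_metadata is unpacked second, its keys win over generation_info on collision. A quick check with made-up values:

# Made-up values; demonstrates the dict-unpacking precedence used above.
generation_info = {"finish_reason": "stop", "model_name": "from-gen-info"}
response_metadata = {"model_name": "from-message"}
merged = {**(generation_info or {}), **response_metadata}
assert merged == {"finish_reason": "stop", "model_name": "from-message"}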
libs/core/langchain_core/messages/ai.py

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class AIMessage(BaseMessage):

@@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
             return self.__class__(
                 example=self.example,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)
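With this change applied, adding two chunks merges additional_kwargs and the new response_metadata through merge_dicts. A usage sketch, assuming langchain_core at this revision is installed (the metadata values are invented):

# Assumes langchain_core with this diff applied; metadata values are invented.
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="Hello ", response_metadata={"model": "demo"})
right = AIMessageChunk(content="world", response_metadata={"finish_reason": "stop"})
total = left + right
print(total.content)            # Hello world
print(total.response_metadata)  # {'model': 'demo', 'finish_reason': 'stop'}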
libs/core/langchain_core/messages/base.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import Extra, Field
 from langchain_core.utils import get_bolded_text
+from langchain_core.utils._merge import merge_dicts
 from langchain_core.utils.interactive_env import is_interactive_env

 if TYPE_CHECKING:

@@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "messages"]

-    def _merge_kwargs_dict(
-        self, left: Dict[str, Any], right: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        """Merge additional_kwargs from another BaseMessageChunk into this one,
-        handling specific scenarios where a key exists in both dictionaries
-        but has a value of None in 'left'. In such cases, the method uses the
-        value from 'right' for that key in the merged dictionary.
-        Example:
-        If left = {"function_call": {"arguments": None}} and
-        right = {"function_call": {"arguments": "{\n"}}
-        then, after merging, for the key "function_call",
-        the value from 'right' is used,
-        resulting in merged = {"function_call": {"arguments": "{\n"}}.
-        """
-        merged = left.copy()
-        for k, v in right.items():
-            if k not in merged:
-                merged[k] = v
-            elif merged[k] is None and v:
-                merged[k] = v
-            elif v is None:
-                continue
-            elif merged[k] == v:
-                continue
-            elif type(merged[k]) != type(v):
-                raise TypeError(
-                    f'additional_kwargs["{k}"] already exists in this message,'
-                    " but with a different type."
-                )
-            elif isinstance(merged[k], str):
-                merged[k] += v
-            elif isinstance(merged[k], dict):
-                merged[k] = self._merge_kwargs_dict(merged[k], v)
-            elif isinstance(merged[k], list):
-                merged[k] = merged[k].copy()
-                for i, e in enumerate(v):
-                    if isinstance(e, dict) and isinstance(e.get("index"), int):
-                        i = e["index"]
-                    if i < len(merged[k]):
-                        merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
-                    else:
-                        merged[k] = merged[k] + [e]
-            else:
-                raise TypeError(
-                    f"Additional kwargs key {k} already exists in this message."
-                )
-        return merged
-
     def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
         if isinstance(other, BaseMessageChunk):
             # If both are (subclasses of) BaseMessageChunk,

@@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage):
             return self.__class__(  # type: ignore[call-arg]
                 id=self.id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             raise TypeError(
libs/core/langchain_core/messages/chat.py

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ChatMessage(BaseMessage):

@@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         elif isinstance(other, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             return super().__add__(other)
libs/core/langchain_core/messages/function.py

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class FunctionMessage(BaseMessage):

@@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
             return self.__class__(
                 name=self.name,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)
libs/core/langchain_core/messages/tool.py

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ToolMessage(BaseMessage):

@@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
             return self.__class__(
                 tool_call_id=self.tool_call_id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )

         return super().__add__(other)
libs/core/langchain_core/utils/_merge.py

@@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
     resulting in merged = {"function_call": {"arguments": "{\n"}}.
     """
     merged = left.copy()
-    for k, v in right.items():
-        if k not in merged:
-            merged[k] = v
-        elif v is not None and merged[k] is None:
-            merged[k] = v
-        elif v is None or merged[k] == v:
+    for right_k, right_v in right.items():
+        if right_k not in merged:
+            merged[right_k] = right_v
+        elif right_v is not None and merged[right_k] is None:
+            merged[right_k] = right_v
+        elif right_v is None:
             continue
-        elif type(merged[k]) != type(v):
+        elif type(merged[right_k]) != type(right_v):
             raise TypeError(
-                f'additional_kwargs["{k}"] already exists in this message,'
+                f'additional_kwargs["{right_k}"] already exists in this message,'
                 " but with a different type."
             )
-        elif isinstance(merged[k], str):
-            merged[k] += v
-        elif isinstance(merged[k], dict):
-            merged[k] = merge_dicts(merged[k], v)
-        elif isinstance(merged[k], list):
-            merged[k] = merged[k] + v
+        elif isinstance(merged[right_k], str):
+            merged[right_k] += right_v
+        elif isinstance(merged[right_k], dict):
+            merged[right_k] = merge_dicts(merged[right_k], right_v)
+        elif isinstance(merged[right_k], list):
+            merged[right_k] = merged[right_k].copy()
+            for e in right_v:
+                if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
+                    to_merge = [
+                        i
+                        for i, e_left in enumerate(merged[right_k])
+                        if e_left["index"] == e["index"]
+                    ]
+                    if to_merge:
+                        merged[right_k][to_merge[0]] = merge_dicts(
+                            merged[right_k][to_merge[0]], e
+                        )
+                    else:
+                        merged[right_k] = merged[right_k] + [e]
+                else:
+                    merged[right_k] = merged[right_k] + [e]
+        elif merged[right_k] == right_v:
+            continue
         else:
             raise TypeError(
-                f"Additional kwargs key {k} already exists in left dict and value has "
-                f"unsupported type {type(merged[k])}."
+                f"Additional kwargs key {right_k} already exists in left dict and "
+                f"value has unsupported type {type(merged[right_k])}."
             )
     return merged
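The new semantics are easiest to see directly: strings concatenate, list elements tagged with the same integer "index" merge in place, and untagged elements append. A short check against the function above (the values are arbitrary):

# Exercises the new merge_dicts semantics shown above; values are arbitrary.
from langchain_core.utils._merge import merge_dicts

# Strings concatenate (this is the behavior change the tests below pick up).
assert merge_dicts({"a": "foo"}, {"a": "bar"}) == {"a": "foobar"}

# List elements tagged with the same "index" merge element-wise...
left = {"tool_calls": [{"index": 0, "args": '{"x": 1'}]}
right = {"tool_calls": [{"index": 0, "args": "}"}]}
assert merge_dicts(left, right) == {"tool_calls": [{"index": 0, "args": '{"x": 1}'}]}

# ...while untagged elements are simply appended.
assert merge_dicts({"a": [1]}, {"a": [2]}) == {"a": [1, 2]}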
libs/core/tests/unit_tests/utils/test_utils.py

@@ -48,9 +48,9 @@ def test_check_package_version(
         ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
         ({"a": True}, {"a": True}, {"a": True}),
         ({"a": False}, {"a": False}, {"a": False}),
-        ({"a": "txt"}, {"a": "txt"}, {"a": "txt"}),
-        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}),
-        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}),
+        ({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
+        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
+        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
         # Merge strings.
         ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
         # Merge dicts.

@@ -89,6 +89,17 @@ def test_check_package_version(
                 ),
             ),
         ),
+        # 'index' keyword has special handling
+        (
+            {"a": [{"index": 0, "b": "{"}]},
+            {"a": [{"index": 0, "b": "f"}]},
+            {"a": [{"index": 0, "b": "{f"}]},
+        ),
+        (
+            {"a": [{"idx": 0, "b": "{"}]},
+            {"a": [{"idx": 0, "b": "f"}]},
+            {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
+        ),
     ),
 )
 def test_merge_dicts(