core[minor]: generation info on msg (#18592)

Related to #16403 and #17188.
Author: Bagatur
Date: 2024-03-11 21:43:17 -07:00 (committed by GitHub)
Commit: e0e688a277 (parent: cda43c5a11)
12 changed files with 357 additions and 164 deletions
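
This change mirrors each generation's generation_info (finish reason, logprobs, token usage, and similar provider details) onto the message itself via response_metadata, for both streamed chunks and batch results, and consolidates the chunk-merging logic into a shared merge_dicts utility. A minimal sketch of the end state, with the message constructed by hand rather than returned from a provider (the metadata keys shown are illustrative):

# Illustrative only: in practice a chat model's _generate/_stream populates
# these fields; we build the message directly to show the resulting shape.
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="Hello!",
    # Mirrored from ChatGeneration.generation_info by _gen_info_and_msg_metadata
    response_metadata={"finish_reason": "stop", "logprobs": None},
)
print(msg.response_metadata)  # {'finish_reason': 'stop', 'logprobs': None}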

View File

@@ -15,6 +15,7 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Union,
     cast,
 )
@@ -240,6 +241,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 for chunk in self._stream(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
+                    chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     yield chunk.message
                     if generation is None:
                         generation = chunk
@@ -317,6 +319,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
                 async for chunk in _stream_implementation(
                     messages, stop=stop, run_manager=run_manager, **kwargs
                 ):
+                    chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
                     yield chunk.message
                     if generation is None:
                         generation = chunk
@@ -586,38 +589,35 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._generate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = llm_cache.lookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if new_arg_supported:
-                return self._generate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return self._generate(messages, stop=stop, **kwargs)
+        if inspect.signature(self._generate).parameters.get("run_manager"):
+            result = self._generate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
         else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = llm_cache.lookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = self._generate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = self._generate(messages, stop=stop, **kwargs)
-                llm_cache.update(prompt, llm_string, result.generations)
-                return result
+            result = self._generate(messages, stop=stop, **kwargs)
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            llm_cache.update(prompt, llm_string, result.generations)
+        return result

     async def _agenerate_with_cache(
         self,
@@ -626,38 +626,34 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
-            "run_manager"
-        )
-        disregard_cache = self.cache is not None and not self.cache
         llm_cache = get_llm_cache()
-        if llm_cache is None or disregard_cache:
-            # This happens when langchain.cache is None, but self.cache is True
-            if self.cache is not None and self.cache:
+        check_cache = self.cache or self.cache is None
+        if check_cache:
+            if llm_cache:
+                llm_string = self._get_llm_string(stop=stop, **kwargs)
+                prompt = dumps(messages)
+                cache_val = await llm_cache.alookup(prompt, llm_string)
+                if isinstance(cache_val, list):
+                    return ChatResult(generations=cache_val)
+            elif self.cache is None:
+                pass
+            else:
                 raise ValueError(
                     "Asked to cache, but no cache found at `langchain.cache`."
                 )
-            if new_arg_supported:
-                return await self._agenerate(
-                    messages, stop=stop, run_manager=run_manager, **kwargs
-                )
-            else:
-                return await self._agenerate(messages, stop=stop, **kwargs)
+        if inspect.signature(self._agenerate).parameters.get("run_manager"):
+            result = await self._agenerate(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
         else:
-            llm_string = self._get_llm_string(stop=stop, **kwargs)
-            prompt = dumps(messages)
-            cache_val = await llm_cache.alookup(prompt, llm_string)
-            if isinstance(cache_val, list):
-                return ChatResult(generations=cache_val)
-            else:
-                if new_arg_supported:
-                    result = await self._agenerate(
-                        messages, stop=stop, run_manager=run_manager, **kwargs
-                    )
-                else:
-                    result = await self._agenerate(messages, stop=stop, **kwargs)
-                await llm_cache.aupdate(prompt, llm_string, result.generations)
-                return result
+            result = await self._agenerate(messages, stop=stop, **kwargs)
+        for generation in result.generations:
+            generation.message.response_metadata = _gen_info_and_msg_metadata(
+                generation
+            )
+        if check_cache and llm_cache:
+            await llm_cache.aupdate(prompt, llm_string, result.generations)
+        return result

     @abstractmethod
     def _generate(
@@ -852,3 +848,12 @@ class SimpleChatModel(BaseChatModel):
             run_manager=run_manager.get_sync() if run_manager else None,
             **kwargs,
         )
+
+
+def _gen_info_and_msg_metadata(
+    generation: Union[ChatGeneration, ChatGenerationChunk],
+) -> dict:
+    return {
+        **(generation.generation_info or {}),
+        **generation.message.response_metadata,
+    }
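
Two small pieces of the logic above are worth restating in isolation: the cache flag is now three-state (True always consults a cache, None defers to the globally configured cache, False bypasses caching), and _gen_info_and_msg_metadata merges generation_info into response_metadata with any pre-existing response_metadata keys winning on conflict. A standalone sketch of both rules (re-implementations for illustration, not the methods themselves):

from typing import Optional

def check_cache(cache: Optional[bool]) -> bool:
    # True -> consult the cache; None -> defer to get_llm_cache(); False -> skip.
    return bool(cache or cache is None)

assert check_cache(True) and check_cache(None) and not check_cache(False)

# On key conflicts, response_metadata (the right-hand spread) wins:
gen_info = {"finish_reason": "stop", "model": "a"}
resp_meta = {"model": "b"}
assert {**gen_info, **resp_meta} == {"finish_reason": "stop", "model": "b"}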

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class AIMessage(BaseMessage):
@@ -49,9 +50,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
             return self.__class__(
                 example=self.example,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         return super().__add__(other)
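
With this change, adding streamed chunks together also merges their response_metadata instead of dropping it. A small sketch of the effect (the metadata values are made up):

from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="Hello, ", response_metadata={"model_name": "x"})
right = AIMessageChunk(content="world", response_metadata={"finish_reason": "stop"})
merged = left + right
assert merged.content == "Hello, world"
assert merged.response_metadata == {"model_name": "x", "finish_reason": "stop"}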

View File

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
 from langchain_core.load.serializable import Serializable
 from langchain_core.pydantic_v1 import Extra, Field
 from langchain_core.utils import get_bolded_text
+from langchain_core.utils._merge import merge_dicts
 from langchain_core.utils.interactive_env import is_interactive_env

 if TYPE_CHECKING:
@@ -114,54 +115,6 @@ class BaseMessageChunk(BaseMessage):
         """Get the namespace of the langchain object."""
         return ["langchain", "schema", "messages"]

-    def _merge_kwargs_dict(
-        self, left: Dict[str, Any], right: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        """Merge additional_kwargs from another BaseMessageChunk into this one,
-        handling specific scenarios where a key exists in both dictionaries
-        but has a value of None in 'left'. In such cases, the method uses the
-        value from 'right' for that key in the merged dictionary.
-
-        Example:
-            If left = {"function_call": {"arguments": None}} and
-            right = {"function_call": {"arguments": "{\n"}}
-            then, after merging, for the key "function_call",
-            the value from 'right' is used,
-            resulting in merged = {"function_call": {"arguments": "{\n"}}.
-        """
-        merged = left.copy()
-        for k, v in right.items():
-            if k not in merged:
-                merged[k] = v
-            elif merged[k] is None and v:
-                merged[k] = v
-            elif v is None:
-                continue
-            elif merged[k] == v:
-                continue
-            elif type(merged[k]) != type(v):
-                raise TypeError(
-                    f'additional_kwargs["{k}"] already exists in this message,'
-                    " but with a different type."
-                )
-            elif isinstance(merged[k], str):
-                merged[k] += v
-            elif isinstance(merged[k], dict):
-                merged[k] = self._merge_kwargs_dict(merged[k], v)
-            elif isinstance(merged[k], list):
-                merged[k] = merged[k].copy()
-                for i, e in enumerate(v):
-                    if isinstance(e, dict) and isinstance(e.get("index"), int):
-                        i = e["index"]
-                    if i < len(merged[k]):
-                        merged[k][i] = self._merge_kwargs_dict(merged[k][i], e)
-                    else:
-                        merged[k] = merged[k] + [e]
-            else:
-                raise TypeError(
-                    f"Additional kwargs key {k} already exists in this message."
-                )
-        return merged
-
     def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
         if isinstance(other, BaseMessageChunk):
             # If both are (subclasses of) BaseMessageChunk,
@@ -170,9 +123,12 @@ class BaseMessageChunk(BaseMessage):
             return self.__class__(  # type: ignore[call-arg]
                 id=self.id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             raise TypeError(
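
The docstring example from the removed method carries over unchanged to the shared helper; a quick check against the merge_dicts that now backs every chunk's __add__:

from langchain_core.utils._merge import merge_dicts

left = {"function_call": {"arguments": None}}
right = {"function_call": {"arguments": "{\n"}}
assert merge_dicts(left, right) == {"function_call": {"arguments": "{\n"}}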

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ChatMessage(BaseMessage):
@@ -47,17 +48,23 @@ class ChatMessageChunk(ChatMessage, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         elif isinstance(other, BaseMessageChunk):
             return self.__class__(
                 role=self.role,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         else:
             return super().__add__(other)

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class FunctionMessage(BaseMessage):
@@ -47,9 +48,12 @@ class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
             return self.__class__(
                 name=self.name,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         return super().__add__(other)

View File

@@ -5,6 +5,7 @@ from langchain_core.messages.base import (
     BaseMessageChunk,
     merge_content,
 )
+from langchain_core.utils._merge import merge_dicts


 class ToolMessage(BaseMessage):
@@ -47,9 +48,12 @@ class ToolMessageChunk(ToolMessage, BaseMessageChunk):
             return self.__class__(
                 tool_call_id=self.tool_call_id,
                 content=merge_content(self.content, other.content),
-                additional_kwargs=self._merge_kwargs_dict(
+                additional_kwargs=merge_dicts(
                     self.additional_kwargs, other.additional_kwargs
                 ),
+                response_metadata=merge_dicts(
+                    self.response_metadata, other.response_metadata
+                ),
             )
         return super().__add__(other)

View File

@@ -16,27 +16,44 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
         resulting in merged = {"function_call": {"arguments": "{\n"}}.
     """
     merged = left.copy()
-    for k, v in right.items():
-        if k not in merged:
-            merged[k] = v
-        elif v is not None and merged[k] is None:
-            merged[k] = v
-        elif v is None or merged[k] == v:
+    for right_k, right_v in right.items():
+        if right_k not in merged:
+            merged[right_k] = right_v
+        elif right_v is not None and merged[right_k] is None:
+            merged[right_k] = right_v
+        elif right_v is None:
             continue
-        elif type(merged[k]) != type(v):
+        elif type(merged[right_k]) != type(right_v):
             raise TypeError(
-                f'additional_kwargs["{k}"] already exists in this message,'
+                f'additional_kwargs["{right_k}"] already exists in this message,'
                 " but with a different type."
             )
-        elif isinstance(merged[k], str):
-            merged[k] += v
-        elif isinstance(merged[k], dict):
-            merged[k] = merge_dicts(merged[k], v)
-        elif isinstance(merged[k], list):
-            merged[k] = merged[k] + v
+        elif isinstance(merged[right_k], str):
+            merged[right_k] += right_v
+        elif isinstance(merged[right_k], dict):
+            merged[right_k] = merge_dicts(merged[right_k], right_v)
+        elif isinstance(merged[right_k], list):
+            merged[right_k] = merged[right_k].copy()
+            for e in right_v:
+                if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
+                    to_merge = [
+                        i
+                        for i, e_left in enumerate(merged[right_k])
+                        if e_left["index"] == e["index"]
+                    ]
+                    if to_merge:
+                        merged[right_k][to_merge[0]] = merge_dicts(
+                            merged[right_k][to_merge[0]], e
+                        )
+                    else:
+                        merged[right_k] = merged[right_k] + [e]
+                else:
+                    merged[right_k] = merged[right_k] + [e]
+        elif merged[right_k] == right_v:
+            continue
         else:
             raise TypeError(
-                f"Additional kwargs key {k} already exists in left dict and value has "
-                f"unsupported type {type(merged[k])}."
+                f"Additional kwargs key {right_k} already exists in left dict and "
+                f"value has unsupported type {type(merged[right_k])}."
             )
     return merged
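
The new list branch gives dict elements that carry an integer "index" key special treatment: instead of being appended, they are merged element-wise with the left-hand entry that has the same index, which is what lets streamed tool/function-call fragments accumulate correctly. A short sketch (the "tool_calls" key name is just an example):

from langchain_core.utils._merge import merge_dicts

# Fragments with a matching "index" merge in place (their strings concatenate):
left = {"tool_calls": [{"index": 0, "args": '{"a'}]}
right = {"tool_calls": [{"index": 0, "args": '": 1}'}]}
assert merge_dicts(left, right) == {"tool_calls": [{"index": 0, "args": '{"a": 1}'}]}

# Dict elements without an "index" key are simply appended:
assert merge_dicts({"a": [{"idx": 0}]}, {"a": [{"idx": 0}]}) == {
    "a": [{"idx": 0}, {"idx": 0}]
}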

View File

@@ -48,9 +48,9 @@ def test_check_package_version(
         ({"a": 1.5}, {"a": 1.5}, {"a": 1.5}),
         ({"a": True}, {"a": True}, {"a": True}),
         ({"a": False}, {"a": False}, {"a": False}),
-        ({"a": "txt"}, {"a": "txt"}, {"a": "txt"}),
-        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2]}),
-        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txt"}}),
+        ({"a": "txt"}, {"a": "txt"}, {"a": "txttxt"}),
+        ({"a": [1, 2]}, {"a": [1, 2]}, {"a": [1, 2, 1, 2]}),
+        ({"a": {"b": "txt"}}, {"a": {"b": "txt"}}, {"a": {"b": "txttxt"}}),
         # Merge strings.
         ({"a": "one"}, {"a": "two"}, {"a": "onetwo"}),
         # Merge dicts.
@@ -89,6 +89,17 @@ def test_check_package_version(
                 ),
             ),
         ),
+        # 'index' keyword has special handling
+        (
+            {"a": [{"index": 0, "b": "{"}]},
+            {"a": [{"index": 0, "b": "f"}]},
+            {"a": [{"index": 0, "b": "{f"}]},
+        ),
+        (
+            {"a": [{"idx": 0, "b": "{"}]},
+            {"a": [{"idx": 0, "b": "f"}]},
+            {"a": [{"idx": 0, "b": "{"}, {"idx": 0, "b": "f"}]},
+        ),
     ),
 )
 def test_merge_dicts(