core: Move json parsing in base chat model / output parser to bg thread (#24031)

- add version of AIMessageChunk.__add__ that can add many chunks,
instead of only 2
- In agenerate_from_stream merge and parse chunks in bg thread
- In output parse base classes do more work in bg threads where
appropriate

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Nuno Campos 2024-07-09 20:26:36 +01:00 committed by GitHub
parent 73966e693c
commit 160fc7f246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 191 additions and 164 deletions

View File

@ -94,13 +94,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
ChatResult: Chat result.
"""
generation: Optional[ChatGenerationChunk] = None
for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
generation = next(stream, None)
if generation:
generation += list(stream)
if generation is None:
raise ValueError("No generations found in stream.")
return ChatResult(
generations=[
ChatGeneration(
@ -123,21 +121,8 @@ async def agenerate_from_stream(
ChatResult: Chat result.
"""
generation: Optional[ChatGenerationChunk] = None
async for chunk in stream:
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(
generations=[
ChatGeneration(
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
]
)
chunks = [chunk async for chunk in stream]
return await run_in_executor(None, generate_from_stream, iter(chunks))
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

View File

@ -267,64 +267,69 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk):
if self.example != other.example:
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
content = merge_content(self.content, other.content)
additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs
)
response_metadata = merge_dicts(
self.response_metadata, other.response_metadata
)
# Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks:
raw_tool_calls = merge_lists(
self.tool_call_chunks,
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
else:
tool_call_chunks = []
# Token usage
if self.usage_metadata or other.usage_metadata:
left: UsageMetadata = self.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0
)
right: UsageMetadata = other.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0
)
usage_metadata: Optional[UsageMetadata] = {
"input_tokens": left["input_tokens"] + right["input_tokens"],
"output_tokens": left["output_tokens"] + right["output_tokens"],
"total_tokens": left["total_tokens"] + right["total_tokens"],
}
else:
usage_metadata = None
return self.__class__(
example=self.example,
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
usage_metadata=usage_metadata,
id=self.id,
)
return add_ai_message_chunks(self, other)
elif isinstance(other, (list, tuple)) and all(
isinstance(o, AIMessageChunk) for o in other
):
return add_ai_message_chunks(self, *other)
return super().__add__(other)
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
"""Add multiple AIMessageChunks together."""
if any(left.example != o.example for o in others):
raise ValueError(
"Cannot concatenate AIMessageChunks with different example values."
)
content = merge_content(left.content, *(o.content for o in others))
additional_kwargs = merge_dicts(
left.additional_kwargs, *(o.additional_kwargs for o in others)
)
response_metadata = merge_dicts(
left.response_metadata, *(o.response_metadata for o in others)
)
# Merge tool call chunks
if raw_tool_calls := merge_lists(
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
):
tool_call_chunks = [
ToolCallChunk(
name=rtc.get("name"),
args=rtc.get("args"),
index=rtc.get("index"),
id=rtc.get("id"),
)
for rtc in raw_tool_calls
]
else:
tool_call_chunks = []
# Token usage
if left.usage_metadata or any(o.usage_metadata is not None for o in others):
usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0
)
for other in others:
if other.usage_metadata is not None:
usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
usage_metadata_["output_tokens"] += other.usage_metadata[
"output_tokens"
]
usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
usage_metadata: Optional[UsageMetadata] = usage_metadata_
else:
usage_metadata = None
return left.__class__(
example=left.example,
content=content,
additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata,
usage_metadata=usage_metadata,
id=left.id,
)

View File

@ -111,7 +111,7 @@ class BaseMessage(Serializable):
def merge_content(
first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]],
*contents: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
"""Merge two message contents.
@ -122,31 +122,32 @@ def merge_content(
Returns:
The merged content.
"""
# If first chunk is a string
if isinstance(first_content, str):
# If the second chunk is also a string, then merge them naively
if isinstance(second_content, str):
return first_content + second_content
# If the second chunk is a list, add the first chunk to the start of the list
merged = first_content
for content in contents:
# If current is a string
if isinstance(merged, str):
# If the next chunk is also a string, then merge them naively
if isinstance(content, str):
merged = cast(str, merged) + content
# If the next chunk is a list, add the current to the start of the list
else:
merged = [merged] + content # type: ignore
elif isinstance(content, list):
# If both are lists
merged = merge_lists(cast(List, merged), content) # type: ignore
# If the first content is a list, and the second content is a string
else:
return_list: List[Union[str, Dict]] = [first_content]
return return_list + second_content
elif isinstance(second_content, List):
# If both are lists
merged_list = merge_lists(first_content, second_content)
return cast(list, merged_list)
# If the first content is a list, and the second content is a string
else:
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(first_content[-1], str):
return first_content[:-1] + [first_content[-1] + second_content]
# If second content is an empty string, treat as a no-op
elif second_content == "":
return first_content
else:
# Otherwise, add the second content as a new element of the list
return first_content + [second_content]
# If the last element of the first content is a string
# Add the second content to the last element
if isinstance(merged[-1], str):
merged[-1] += content
# If second content is an empty string, treat as a no-op
elif content == "":
pass
else:
# Otherwise, add the second content as a new element of the list
merged.append(content)
return merged
class BaseMessageChunk(BaseMessage):
@ -195,6 +196,22 @@ class BaseMessageChunk(BaseMessage):
self.response_metadata, other.response_metadata
),
)
elif isinstance(other, list) and all(
isinstance(o, BaseMessageChunk) for o in other
):
content = merge_content(self.content, *(o.content for o in other))
additional_kwargs = merge_dicts(
self.additional_kwargs, *(o.additional_kwargs for o in other)
)
response_metadata = merge_dicts(
self.response_metadata, *(o.response_metadata for o in other)
)
return self.__class__( # type: ignore[call-arg]
id=self.id,
content=content,
additional_kwargs=additional_kwargs,
response_metadata=response_metadata,
)
else:
raise TypeError(
'unsupported operand type(s) for +: "'

View File

@ -17,6 +17,7 @@ from langchain_core.outputs import (
Generation,
GenerationChunk,
)
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig
@ -37,9 +38,13 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
) -> AsyncIterator[T]:
async for chunk in input:
if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)])
yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)]
)
else:
yield self.parse_result([Generation(text=chunk)])
yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)]
)
def transform(
self,
@ -153,7 +158,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed:
if self.diff:
yield self._diff(prev_parsed, parsed)
yield await run_in_executor(None, self._diff, prev_parsed, parsed)
else:
yield parsed
prev_parsed = parsed

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, List, Literal
from typing import Any, Dict, List, Literal, Union
from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation
@ -88,7 +88,9 @@ class ChatGenerationChunk(ChatGeneration):
"""Get the namespace of the langchain object."""
return ["langchain", "schema", "output"]
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
def __add__(
self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk):
generation_info = merge_dicts(
self.generation_info or {},
@ -98,6 +100,17 @@ class ChatGenerationChunk(ChatGeneration):
message=self.message + other.message,
generation_info=generation_info or None,
)
elif isinstance(other, list) and all(
isinstance(x, ChatGenerationChunk) for x in other
):
generation_info = merge_dicts(
self.generation_info or {},
*[chunk.generation_info for chunk in other if chunk.generation_info],
)
return ChatGenerationChunk(
message=self.message + [chunk.message for chunk in other],
generation_info=generation_info or None,
)
else:
raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"

View File

@ -3,8 +3,8 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional
def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
"""Merge two dicts, handling specific scenarios where a key exists in both
def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
"""Merge many dicts, handling specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary.
@ -16,57 +16,59 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
resulting in merged = {"function_call": {"arguments": "{\n"}}.
"""
merged = left.copy()
for right_k, right_v in right.items():
if right_k not in merged:
merged[right_k] = right_v
elif right_v is not None and merged[right_k] is None:
merged[right_k] = right_v
elif right_v is None:
continue
elif type(merged[right_k]) is not type(right_v):
raise TypeError(
f'additional_kwargs["{right_k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[right_k], str):
merged[right_k] += right_v
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
elif isinstance(merged[right_k], list):
merged[right_k] = merge_lists(merged[right_k], right_v)
elif merged[right_k] == right_v:
continue
else:
raise TypeError(
f"Additional kwargs key {right_k} already exists in left dict and "
f"value has unsupported type {type(merged[right_k])}."
)
for right in others:
for right_k, right_v in right.items():
if right_k not in merged:
merged[right_k] = right_v
elif right_v is not None and merged[right_k] is None:
merged[right_k] = right_v
elif right_v is None:
continue
elif type(merged[right_k]) is not type(right_v):
raise TypeError(
f'additional_kwargs["{right_k}"] already exists in this message,'
" but with a different type."
)
elif isinstance(merged[right_k], str):
merged[right_k] += right_v
elif isinstance(merged[right_k], dict):
merged[right_k] = merge_dicts(merged[right_k], right_v)
elif isinstance(merged[right_k], list):
merged[right_k] = merge_lists(merged[right_k], right_v)
elif merged[right_k] == right_v:
continue
else:
raise TypeError(
f"Additional kwargs key {right_k} already exists in left dict and "
f"value has unsupported type {type(merged[right_k])}."
)
return merged
def merge_lists(left: Optional[List], right: Optional[List]) -> Optional[List]:
"""Add two lists, handling None."""
if left is None and right is None:
return None
elif left is None or right is None:
return left or right
else:
merged = left.copy()
for e in right:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged)
if e_left["index"] == e["index"]
]
if to_merge:
# If a top-level "type" has been set for a chunk, it should no
# longer be overridden by the "type" field in future chunks.
if "type" in merged[to_merge[0]] and "type" in e:
e.pop("type")
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add many lists, handling None."""
merged = left.copy() if left is not None else None
for other in others:
if other is None:
continue
elif merged is None:
merged = other.copy()
else:
for e in other:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [
i
for i, e_left in enumerate(merged)
if e_left["index"] == e["index"]
]
if to_merge:
# If a top-level "type" has been set for a chunk, it should no
# longer be overridden by the "type" field in future chunks.
if "type" in merged[to_merge[0]] and "type" in e:
e.pop("type")
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
else:
merged.append(e)
else:
merged = merged + [e]
else:
merged = merged + [e]
return merged
merged.append(e)
return merged