core: Move json parsing in base chat model / output parser to bg thread (#24031)

- add version of AIMessageChunk.__add__ that can add many chunks,
instead of only 2
- In agenerate_from_stream merge and parse chunks in bg thread
- In output parse base classes do more work in bg threads where
appropriate

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Nuno Campos 2024-07-09 20:26:36 +01:00 committed by GitHub
parent 73966e693c
commit 160fc7f246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 191 additions and 164 deletions

View File

@ -94,13 +94,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
ChatResult: Chat result. ChatResult: Chat result.
""" """
generation: Optional[ChatGenerationChunk] = None generation = next(stream, None)
for chunk in stream: if generation:
generation += list(stream)
if generation is None: if generation is None:
generation = chunk raise ValueError("No generations found in stream.")
else:
generation += chunk
assert generation is not None
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration( ChatGeneration(
@ -123,21 +121,8 @@ async def agenerate_from_stream(
ChatResult: Chat result. ChatResult: Chat result.
""" """
generation: Optional[ChatGenerationChunk] = None chunks = [chunk async for chunk in stream]
async for chunk in stream: return await run_in_executor(None, generate_from_stream, iter(chunks))
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(
generations=[
ChatGeneration(
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
]
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

View File

@ -267,26 +267,35 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
if isinstance(other, AIMessageChunk): if isinstance(other, AIMessageChunk):
if self.example != other.example: return add_ai_message_chunks(self, other)
elif isinstance(other, (list, tuple)) and all(
isinstance(o, AIMessageChunk) for o in other
):
return add_ai_message_chunks(self, *other)
return super().__add__(other)
def add_ai_message_chunks(
left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
"""Add multiple AIMessageChunks together."""
if any(left.example != o.example for o in others):
raise ValueError( raise ValueError(
"Cannot concatenate AIMessageChunks with different example values." "Cannot concatenate AIMessageChunks with different example values."
) )
content = merge_content(self.content, other.content) content = merge_content(left.content, *(o.content for o in others))
additional_kwargs = merge_dicts( additional_kwargs = merge_dicts(
self.additional_kwargs, other.additional_kwargs left.additional_kwargs, *(o.additional_kwargs for o in others)
) )
response_metadata = merge_dicts( response_metadata = merge_dicts(
self.response_metadata, other.response_metadata left.response_metadata, *(o.response_metadata for o in others)
) )
# Merge tool call chunks # Merge tool call chunks
if self.tool_call_chunks or other.tool_call_chunks: if raw_tool_calls := merge_lists(
raw_tool_calls = merge_lists( left.tool_call_chunks, *(o.tool_call_chunks for o in others)
self.tool_call_chunks, ):
other.tool_call_chunks,
)
if raw_tool_calls:
tool_call_chunks = [ tool_call_chunks = [
ToolCallChunk( ToolCallChunk(
name=rtc.get("name"), name=rtc.get("name"),
@ -298,33 +307,29 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
] ]
else: else:
tool_call_chunks = [] tool_call_chunks = []
else:
tool_call_chunks = []
# Token usage # Token usage
if self.usage_metadata or other.usage_metadata: if left.usage_metadata or any(o.usage_metadata is not None for o in others):
left: UsageMetadata = self.usage_metadata or UsageMetadata( usage_metadata_: UsageMetadata = left.usage_metadata or UsageMetadata(
input_tokens=0, output_tokens=0, total_tokens=0 input_tokens=0, output_tokens=0, total_tokens=0
) )
right: UsageMetadata = other.usage_metadata or UsageMetadata( for other in others:
input_tokens=0, output_tokens=0, total_tokens=0 if other.usage_metadata is not None:
) usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
usage_metadata: Optional[UsageMetadata] = { usage_metadata_["output_tokens"] += other.usage_metadata[
"input_tokens": left["input_tokens"] + right["input_tokens"], "output_tokens"
"output_tokens": left["output_tokens"] + right["output_tokens"], ]
"total_tokens": left["total_tokens"] + right["total_tokens"], usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
} usage_metadata: Optional[UsageMetadata] = usage_metadata_
else: else:
usage_metadata = None usage_metadata = None
return self.__class__( return left.__class__(
example=self.example, example=left.example,
content=content, content=content,
additional_kwargs=additional_kwargs, additional_kwargs=additional_kwargs,
tool_call_chunks=tool_call_chunks, tool_call_chunks=tool_call_chunks,
response_metadata=response_metadata, response_metadata=response_metadata,
usage_metadata=usage_metadata, usage_metadata=usage_metadata,
id=self.id, id=left.id,
) )
return super().__add__(other)

View File

@ -111,7 +111,7 @@ class BaseMessage(Serializable):
def merge_content( def merge_content(
first_content: Union[str, List[Union[str, Dict]]], first_content: Union[str, List[Union[str, Dict]]],
second_content: Union[str, List[Union[str, Dict]]], *contents: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]: ) -> Union[str, List[Union[str, Dict]]]:
"""Merge two message contents. """Merge two message contents.
@ -122,31 +122,32 @@ def merge_content(
Returns: Returns:
The merged content. The merged content.
""" """
# If first chunk is a string merged = first_content
if isinstance(first_content, str): for content in contents:
# If the second chunk is also a string, then merge them naively # If current is a string
if isinstance(second_content, str): if isinstance(merged, str):
return first_content + second_content # If the next chunk is also a string, then merge them naively
# If the second chunk is a list, add the first chunk to the start of the list if isinstance(content, str):
merged = cast(str, merged) + content
# If the next chunk is a list, add the current to the start of the list
else: else:
return_list: List[Union[str, Dict]] = [first_content] merged = [merged] + content # type: ignore
return return_list + second_content elif isinstance(content, list):
elif isinstance(second_content, List):
# If both are lists # If both are lists
merged_list = merge_lists(first_content, second_content) merged = merge_lists(cast(List, merged), content) # type: ignore
return cast(list, merged_list)
# If the first content is a list, and the second content is a string # If the first content is a list, and the second content is a string
else: else:
# If the last element of the first content is a string # If the last element of the first content is a string
# Add the second content to the last element # Add the second content to the last element
if isinstance(first_content[-1], str): if isinstance(merged[-1], str):
return first_content[:-1] + [first_content[-1] + second_content] merged[-1] += content
# If second content is an empty string, treat as a no-op # If second content is an empty string, treat as a no-op
elif second_content == "": elif content == "":
return first_content pass
else: else:
# Otherwise, add the second content as a new element of the list # Otherwise, add the second content as a new element of the list
return first_content + [second_content] merged.append(content)
return merged
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
@ -195,6 +196,22 @@ class BaseMessageChunk(BaseMessage):
self.response_metadata, other.response_metadata self.response_metadata, other.response_metadata
), ),
) )
elif isinstance(other, list) and all(
isinstance(o, BaseMessageChunk) for o in other
):
content = merge_content(self.content, *(o.content for o in other))
additional_kwargs = merge_dicts(
self.additional_kwargs, *(o.additional_kwargs for o in other)
)
response_metadata = merge_dicts(
self.response_metadata, *(o.response_metadata for o in other)
)
return self.__class__( # type: ignore[call-arg]
id=self.id,
content=content,
additional_kwargs=additional_kwargs,
response_metadata=response_metadata,
)
else: else:
raise TypeError( raise TypeError(
'unsupported operand type(s) for +: "' 'unsupported operand type(s) for +: "'

View File

@ -17,6 +17,7 @@ from langchain_core.outputs import (
Generation, Generation,
GenerationChunk, GenerationChunk,
) )
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@ -37,9 +38,13 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)]) yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)]
)
else: else:
yield self.parse_result([Generation(text=chunk)]) yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)]
)
def transform( def transform(
self, self,
@ -153,7 +158,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
parsed = await self.aparse_result([acc_gen], partial=True) parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed: if parsed is not None and parsed != prev_parsed:
if self.diff: if self.diff:
yield self._diff(prev_parsed, parsed) yield await run_in_executor(None, self._diff, prev_parsed, parsed)
else: else:
yield parsed yield parsed
prev_parsed = parsed prev_parsed = parsed

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Literal from typing import Any, Dict, List, Literal, Union
from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation from langchain_core.outputs.generation import Generation
@ -88,7 +88,9 @@ class ChatGenerationChunk(ChatGeneration):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "output"] return ["langchain", "schema", "output"]
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: def __add__(
self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk): if isinstance(other, ChatGenerationChunk):
generation_info = merge_dicts( generation_info = merge_dicts(
self.generation_info or {}, self.generation_info or {},
@ -98,6 +100,17 @@ class ChatGenerationChunk(ChatGeneration):
message=self.message + other.message, message=self.message + other.message,
generation_info=generation_info or None, generation_info=generation_info or None,
) )
elif isinstance(other, list) and all(
isinstance(x, ChatGenerationChunk) for x in other
):
generation_info = merge_dicts(
self.generation_info or {},
*[chunk.generation_info for chunk in other if chunk.generation_info],
)
return ChatGenerationChunk(
message=self.message + [chunk.message for chunk in other],
generation_info=generation_info or None,
)
else: else:
raise TypeError( raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"

View File

@ -3,8 +3,8 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]: def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
"""Merge two dicts, handling specific scenarios where a key exists in both """Merge many dicts, handling specific scenarios where a key exists in both
dictionaries but has a value of None in 'left'. In such cases, the method uses the dictionaries but has a value of None in 'left'. In such cases, the method uses the
value from 'right' for that key in the merged dictionary. value from 'right' for that key in the merged dictionary.
@ -16,6 +16,7 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
resulting in merged = {"function_call": {"arguments": "{\n"}}. resulting in merged = {"function_call": {"arguments": "{\n"}}.
""" """
merged = left.copy() merged = left.copy()
for right in others:
for right_k, right_v in right.items(): for right_k, right_v in right.items():
if right_k not in merged: if right_k not in merged:
merged[right_k] = right_v merged[right_k] = right_v
@ -44,15 +45,16 @@ def merge_dicts(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
return merged return merged
def merge_lists(left: Optional[List], right: Optional[List]) -> Optional[List]: def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
"""Add two lists, handling None.""" """Add many lists, handling None."""
if left is None and right is None: merged = left.copy() if left is not None else None
return None for other in others:
elif left is None or right is None: if other is None:
return left or right continue
elif merged is None:
merged = other.copy()
else: else:
merged = left.copy() for e in other:
for e in right:
if isinstance(e, dict) and "index" in e and isinstance(e["index"], int): if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
to_merge = [ to_merge = [
i i
@ -66,7 +68,7 @@ def merge_lists(left: Optional[List], right: Optional[List]) -> Optional[List]:
e.pop("type") e.pop("type")
merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e) merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
else: else:
merged = merged + [e] merged.append(e)
else: else:
merged = merged + [e] merged.append(e)
return merged return merged