core: Move json parsing in base chat model / output parser to bg thread (#24031)

- add version of AIMessageChunk.__add__ that can add many chunks,
instead of only 2
- In agenerate_from_stream merge and parse chunks in bg thread
- In output parse base classes do more work in bg threads where
appropriate

---------

Co-authored-by: William FH <13333726+hinthornw@users.noreply.github.com>
This commit is contained in:
Nuno Campos 2024-07-09 20:26:36 +01:00 committed by GitHub
parent 73966e693c
commit 160fc7f246
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 191 additions and 164 deletions

View File

@ -94,13 +94,11 @@ def generate_from_stream(stream: Iterator[ChatGenerationChunk]) -> ChatResult:
ChatResult: Chat result. ChatResult: Chat result.
""" """
generation: Optional[ChatGenerationChunk] = None generation = next(stream, None)
for chunk in stream: if generation:
if generation is None: generation += list(stream)
generation = chunk if generation is None:
else: raise ValueError("No generations found in stream.")
generation += chunk
assert generation is not None
return ChatResult( return ChatResult(
generations=[ generations=[
ChatGeneration( ChatGeneration(
@ -123,21 +121,8 @@ async def agenerate_from_stream(
ChatResult: Chat result. ChatResult: Chat result.
""" """
generation: Optional[ChatGenerationChunk] = None chunks = [chunk async for chunk in stream]
async for chunk in stream: return await run_in_executor(None, generate_from_stream, iter(chunks))
if generation is None:
generation = chunk
else:
generation += chunk
assert generation is not None
return ChatResult(
generations=[
ChatGeneration(
message=message_chunk_to_message(generation.message),
generation_info=generation.generation_info,
)
]
)
class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):

View File

@ -267,64 +267,69 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
def __add__(self, other: Any) -> BaseMessageChunk:  # type: ignore
    """Concatenate this chunk with another chunk, or with a sequence of chunks.

    A single ``AIMessageChunk`` or a list/tuple of them is merged via
    ``add_ai_message_chunks``; anything else is delegated to the parent class.
    """
    if isinstance(other, AIMessageChunk):
        return add_ai_message_chunks(self, other)
    is_chunk_sequence = isinstance(other, (list, tuple)) and all(
        isinstance(item, AIMessageChunk) for item in other
    )
    if is_chunk_sequence:
        return add_ai_message_chunks(self, *other)
    return super().__add__(other)
def add_ai_message_chunks(
    left: AIMessageChunk, *others: AIMessageChunk
) -> AIMessageChunk:
    """Add multiple AIMessageChunks together.

    Args:
        left: The first chunk.
        *others: Further chunks, folded into ``left`` in order.

    Returns:
        A new chunk of ``left``'s class with merged content, additional
        kwargs, tool call chunks, response metadata and token usage.

    Raises:
        ValueError: If the chunks have differing ``example`` values.
    """
    if any(left.example != o.example for o in others):
        raise ValueError(
            "Cannot concatenate AIMessageChunks with different example values."
        )

    content = merge_content(left.content, *(o.content for o in others))
    additional_kwargs = merge_dicts(
        left.additional_kwargs, *(o.additional_kwargs for o in others)
    )
    response_metadata = merge_dicts(
        left.response_metadata, *(o.response_metadata for o in others)
    )

    # Merge tool call chunks
    if raw_tool_calls := merge_lists(
        left.tool_call_chunks, *(o.tool_call_chunks for o in others)
    ):
        tool_call_chunks = [
            ToolCallChunk(
                name=rtc.get("name"),
                args=rtc.get("args"),
                index=rtc.get("index"),
                id=rtc.get("id"),
            )
            for rtc in raw_tool_calls
        ]
    else:
        tool_call_chunks = []

    # Token usage: accumulate into a *copy*. UsageMetadata is a TypedDict
    # (a plain dict at runtime), so `+=` on left.usage_metadata itself would
    # silently mutate the input chunk.
    usage_metadata: Optional[UsageMetadata]
    if left.usage_metadata or any(o.usage_metadata is not None for o in others):
        usage_metadata_: UsageMetadata = (
            dict(left.usage_metadata)  # type: ignore[assignment]
            if left.usage_metadata
            else UsageMetadata(input_tokens=0, output_tokens=0, total_tokens=0)
        )
        for other in others:
            if other.usage_metadata is not None:
                usage_metadata_["input_tokens"] += other.usage_metadata["input_tokens"]
                usage_metadata_["output_tokens"] += other.usage_metadata[
                    "output_tokens"
                ]
                usage_metadata_["total_tokens"] += other.usage_metadata["total_tokens"]
        usage_metadata = usage_metadata_
    else:
        usage_metadata = None

    return left.__class__(
        example=left.example,
        content=content,
        additional_kwargs=additional_kwargs,
        tool_call_chunks=tool_call_chunks,
        response_metadata=response_metadata,
        usage_metadata=usage_metadata,
        id=left.id,
    )

View File

@ -111,7 +111,7 @@ class BaseMessage(Serializable):
def merge_content(
    first_content: Union[str, List[Union[str, Dict]]],
    *contents: Union[str, List[Union[str, Dict]]],
) -> Union[str, List[Union[str, Dict]]]:
    """Merge multiple message contents.

    Args:
        first_content: The first content.
        *contents: Additional contents, folded in left to right.

    Returns:
        The merged content. The inputs are never mutated: list results are
        built as new lists so a caller's content list is left untouched.
    """
    merged = first_content
    for content in contents:
        # If the merged value so far is a string
        if isinstance(merged, str):
            # string + string: naive concatenation
            if isinstance(content, str):
                merged = cast(str, merged) + content
            # string + list: the string becomes the first element
            else:
                merged = [merged] + content  # type: ignore
        elif isinstance(content, list):
            # list + list
            merged = merge_lists(cast(List, merged), content)  # type: ignore
        elif content == "":
            # Empty string is a no-op
            pass
        elif isinstance(merged[-1], str):
            # list + string where the list ends in a string: extend that
            # trailing string — rebuild rather than `merged[-1] += content`,
            # which would mutate the caller's list in place.
            merged = merged[:-1] + [merged[-1] + content]
        else:
            # Otherwise add the string as a new element (again without
            # mutating the input list).
            merged = merged + [content]
    return merged
class BaseMessageChunk(BaseMessage): class BaseMessageChunk(BaseMessage):
@ -195,6 +196,22 @@ class BaseMessageChunk(BaseMessage):
self.response_metadata, other.response_metadata self.response_metadata, other.response_metadata
), ),
) )
elif isinstance(other, list) and all(
isinstance(o, BaseMessageChunk) for o in other
):
content = merge_content(self.content, *(o.content for o in other))
additional_kwargs = merge_dicts(
self.additional_kwargs, *(o.additional_kwargs for o in other)
)
response_metadata = merge_dicts(
self.response_metadata, *(o.response_metadata for o in other)
)
return self.__class__( # type: ignore[call-arg]
id=self.id,
content=content,
additional_kwargs=additional_kwargs,
response_metadata=response_metadata,
)
else: else:
raise TypeError( raise TypeError(
'unsupported operand type(s) for +: "' 'unsupported operand type(s) for +: "'

View File

@ -17,6 +17,7 @@ from langchain_core.outputs import (
Generation, Generation,
GenerationChunk, GenerationChunk,
) )
from langchain_core.runnables.config import run_in_executor
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@ -37,9 +38,13 @@ class BaseTransformOutputParser(BaseOutputParser[T]):
) -> AsyncIterator[T]: ) -> AsyncIterator[T]:
async for chunk in input: async for chunk in input:
if isinstance(chunk, BaseMessage): if isinstance(chunk, BaseMessage):
yield self.parse_result([ChatGeneration(message=chunk)]) yield await run_in_executor(
None, self.parse_result, [ChatGeneration(message=chunk)]
)
else: else:
yield self.parse_result([Generation(text=chunk)]) yield await run_in_executor(
None, self.parse_result, [Generation(text=chunk)]
)
def transform( def transform(
self, self,
@ -153,7 +158,7 @@ class BaseCumulativeTransformOutputParser(BaseTransformOutputParser[T]):
parsed = await self.aparse_result([acc_gen], partial=True) parsed = await self.aparse_result([acc_gen], partial=True)
if parsed is not None and parsed != prev_parsed: if parsed is not None and parsed != prev_parsed:
if self.diff: if self.diff:
yield self._diff(prev_parsed, parsed) yield await run_in_executor(None, self._diff, prev_parsed, parsed)
else: else:
yield parsed yield parsed
prev_parsed = parsed prev_parsed = parsed

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, Dict, List, Literal from typing import Any, Dict, List, Literal, Union
from langchain_core.messages import BaseMessage, BaseMessageChunk from langchain_core.messages import BaseMessage, BaseMessageChunk
from langchain_core.outputs.generation import Generation from langchain_core.outputs.generation import Generation
@ -88,7 +88,9 @@ class ChatGenerationChunk(ChatGeneration):
"""Get the namespace of the langchain object.""" """Get the namespace of the langchain object."""
return ["langchain", "schema", "output"] return ["langchain", "schema", "output"]
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk: def __add__(
self, other: Union[ChatGenerationChunk, List[ChatGenerationChunk]]
) -> ChatGenerationChunk:
if isinstance(other, ChatGenerationChunk): if isinstance(other, ChatGenerationChunk):
generation_info = merge_dicts( generation_info = merge_dicts(
self.generation_info or {}, self.generation_info or {},
@ -98,6 +100,17 @@ class ChatGenerationChunk(ChatGeneration):
message=self.message + other.message, message=self.message + other.message,
generation_info=generation_info or None, generation_info=generation_info or None,
) )
elif isinstance(other, list) and all(
isinstance(x, ChatGenerationChunk) for x in other
):
generation_info = merge_dicts(
self.generation_info or {},
*[chunk.generation_info for chunk in other if chunk.generation_info],
)
return ChatGenerationChunk(
message=self.message + [chunk.message for chunk in other],
generation_info=generation_info or None,
)
else: else:
raise TypeError( raise TypeError(
f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'" f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'"

View File

@ -3,8 +3,8 @@ from __future__ import annotations
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
def merge_dicts(left: Dict[str, Any], *others: Dict[str, Any]) -> Dict[str, Any]:
    """Merge one or more dicts into a shallow copy of ``left``.

    For a key present on both sides, the values are combined by type:
    a ``None`` on either side yields the other value, strings are
    concatenated, nested dicts are merged recursively, lists are merged
    via ``merge_lists``, and equal values are kept as-is. Any other
    combination raises ``TypeError``.
    """
    merged = left.copy()
    for addition in others:
        for key, new_value in addition.items():
            # Missing key, or a None placeholder being filled in.
            if key not in merged or (merged[key] is None and new_value is not None):
                merged[key] = new_value
                continue
            # None never overwrites an existing value.
            if new_value is None:
                continue
            current = merged[key]
            if type(current) is not type(new_value):
                raise TypeError(
                    f'additional_kwargs["{key}"] already exists in this message,'
                    " but with a different type."
                )
            if isinstance(current, str):
                merged[key] = current + new_value
            elif isinstance(current, dict):
                merged[key] = merge_dicts(current, new_value)
            elif isinstance(current, list):
                merged[key] = merge_lists(current, new_value)
            elif current != new_value:
                raise TypeError(
                    f"Additional kwargs key {key} already exists in left dict and "
                    f"value has unsupported type {type(current)}."
                )
    return merged
def merge_lists(left: Optional[List], *others: Optional[List]) -> Optional[List]:
    """Add many lists, handling None.

    Elements that are dicts carrying an integer ``"index"`` key are merged
    into the existing element with the same index (via ``merge_dicts``);
    every other element is appended. Returns ``None`` only if every input
    is ``None``. Inputs are not mutated.
    """
    merged = left.copy() if left is not None else None
    for other in others:
        if other is None:
            continue
        if merged is None:
            merged = other.copy()
            continue
        for e in other:
            if isinstance(e, dict) and "index" in e and isinstance(e["index"], int):
                to_merge = [
                    i
                    for i, e_left in enumerate(merged)
                    if e_left["index"] == e["index"]
                ]
                if to_merge:
                    # If a top-level "type" has been set for a chunk, it should no
                    # longer be overridden by the "type" field in future chunks.
                    # Drop it from a *copy* — `e.pop("type")` would mutate the
                    # caller's dict inside `other`.
                    if "type" in merged[to_merge[0]] and "type" in e:
                        e = {k: v for k, v in e.items() if k != "type"}
                    merged[to_merge[0]] = merge_dicts(merged[to_merge[0]], e)
                else:
                    merged.append(e)
            else:
                merged.append(e)
    return merged