mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-16 17:53:37 +00:00
core[patch]: include tool_calls in ai msg chunk serialization (#20291)
This commit is contained in:
parent
0fa551c278
commit
03b247cca1
@ -1,5 +1,5 @@
|
||||
import warnings
|
||||
from typing import Any, List, Literal
|
||||
from typing import Any, Dict, List, Literal
|
||||
|
||||
from langchain_core.messages.base import (
|
||||
BaseMessage,
|
||||
@ -40,7 +40,15 @@ class AIMessage(BaseMessage):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@root_validator
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
"""Attrs to be serialized even if they are derived from other init args."""
|
||||
return {
|
||||
"tool_calls": self.tool_calls,
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator()
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||
tool_calls = (
|
||||
@ -88,6 +96,14 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
"""Get the namespace of the langchain object."""
|
||||
return ["langchain", "schema", "messages"]
|
||||
|
||||
@property
|
||||
def lc_attributes(self) -> Dict:
|
||||
"""Attrs to be serialized even if they are derived from other init args."""
|
||||
return {
|
||||
"tool_calls": self.tool_calls,
|
||||
"invalid_tool_calls": self.invalid_tool_calls,
|
||||
}
|
||||
|
||||
@root_validator()
|
||||
def init_tool_calls(cls, values: dict) -> dict:
|
||||
if not values["tool_call_chunks"]:
|
||||
|
67
libs/core/tests/unit_tests/messages/test_ai.py
Normal file
67
libs/core/tests/unit_tests/messages/test_ai.py
Normal file
@ -0,0 +1,67 @@
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
InvalidToolCall,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
)
|
||||
|
||||
|
||||
def test_serdes_message() -> None:
|
||||
msg = AIMessage(
|
||||
content=[{"text": "blah", "type": "text"}],
|
||||
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
|
||||
invalid_tool_calls=[
|
||||
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessage"],
|
||||
"kwargs": {
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"invalid_tool_calls": [
|
||||
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
|
||||
],
|
||||
},
|
||||
}
|
||||
actual = dumpd(msg)
|
||||
assert actual == expected
|
||||
assert load(actual) == msg
|
||||
|
||||
|
||||
def test_serdes_message_chunk() -> None:
|
||||
chunk = AIMessageChunk(
|
||||
content=[{"text": "blah", "type": "text"}],
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
|
||||
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
"lc": 1,
|
||||
"type": "constructor",
|
||||
"id": ["langchain", "schema", "messages", "AIMessageChunk"],
|
||||
"kwargs": {
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"invalid_tool_calls": [
|
||||
{
|
||||
"name": "foobad",
|
||||
"args": "blah",
|
||||
"id": "booz",
|
||||
"error": "Malformed args.",
|
||||
}
|
||||
],
|
||||
"tool_call_chunks": [
|
||||
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
|
||||
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
|
||||
],
|
||||
},
|
||||
}
|
||||
actual = dumpd(chunk)
|
||||
assert actual == expected
|
||||
assert load(actual) == chunk
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user