mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-27 08:58:48 +00:00
core[patch]: fix ToolCall "type" when streaming (#24218)
This commit is contained in:
parent
2b7d1cdd2f
commit
65321bf975
@ -48,6 +48,7 @@ from langchain_core.messages import (
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import ToolCall
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.outputs import (
|
||||
ChatGeneration,
|
||||
ChatGenerationChunk,
|
||||
@ -96,7 +97,7 @@ def _parse_tool_calling(tool_call: dict) -> ToolCall:
|
||||
name = tool_call["function"].get("name", "")
|
||||
args = json.loads(tool_call["function"]["arguments"])
|
||||
id = tool_call.get("id")
|
||||
return ToolCall(name=name, args=args, id=id)
|
||||
return create_tool_call(name=name, args=args, id=id)
|
||||
|
||||
|
||||
def _convert_to_tool_calling(tool_call: ToolCall) -> Dict[str, Any]:
|
||||
|
@ -36,9 +36,11 @@ from langchain_core.messages import (
|
||||
InvalidToolCall,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.output_parsers.base import OutputParserLike
|
||||
from langchain_core.output_parsers.openai_tools import (
|
||||
JsonOutputKeyToolsParser,
|
||||
@ -63,7 +65,7 @@ def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationCh
|
||||
message = generated_result.generations[0].message
|
||||
if isinstance(message, AIMessage) and message.tool_calls is not None:
|
||||
tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
create_tool_call_chunk(
|
||||
name=tool_call["name"],
|
||||
args=json.dumps(tool_call["args"]),
|
||||
id=tool_call["id"],
|
||||
@ -189,7 +191,7 @@ def _extract_tool_calls_from_edenai_response(
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
try:
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
create_tool_call(
|
||||
name=raw_tool_call["name"],
|
||||
args=json.loads(raw_tool_call["arguments"]),
|
||||
id=raw_tool_call["id"],
|
||||
@ -197,7 +199,7 @@ def _extract_tool_calls_from_edenai_response(
|
||||
)
|
||||
except json.JSONDecodeError as exc:
|
||||
invalid_tool_calls.append(
|
||||
InvalidToolCall(
|
||||
create_invalid_tool_call(
|
||||
name=raw_tool_call.get("name"),
|
||||
args=raw_tool_call.get("arguments"),
|
||||
id=raw_tool_call.get("id"),
|
||||
|
@ -15,11 +15,18 @@ from langchain_core.messages.tool import (
|
||||
default_tool_chunk_parser,
|
||||
default_tool_parser,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
invalid_tool_call as create_invalid_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call as create_tool_call,
|
||||
)
|
||||
from langchain_core.messages.tool import (
|
||||
tool_call_chunk as create_tool_call_chunk,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import root_validator
|
||||
from langchain_core.utils._merge import merge_dicts, merge_lists
|
||||
from langchain_core.utils.json import (
|
||||
parse_partial_json,
|
||||
)
|
||||
from langchain_core.utils.json import parse_partial_json
|
||||
|
||||
|
||||
class UsageMetadata(TypedDict):
|
||||
@ -106,24 +113,55 @@ class AIMessage(BaseMessage):
|
||||
|
||||
@root_validator(pre=True)
|
||||
def _backwards_compat_tool_calls(cls, values: dict) -> dict:
|
||||
raw_tool_calls = values.get("additional_kwargs", {}).get("tool_calls")
|
||||
tool_calls = (
|
||||
values.get("tool_calls")
|
||||
or values.get("invalid_tool_calls")
|
||||
or values.get("tool_call_chunks")
|
||||
check_additional_kwargs = not any(
|
||||
values.get(k)
|
||||
for k in ("tool_calls", "invalid_tool_calls", "tool_call_chunks")
|
||||
)
|
||||
if raw_tool_calls and not tool_calls:
|
||||
if check_additional_kwargs and (
|
||||
raw_tool_calls := values.get("additional_kwargs", {}).get("tool_calls")
|
||||
):
|
||||
try:
|
||||
if issubclass(cls, AIMessageChunk): # type: ignore
|
||||
values["tool_call_chunks"] = default_tool_chunk_parser(
|
||||
raw_tool_calls
|
||||
)
|
||||
else:
|
||||
tool_calls, invalid_tool_calls = default_tool_parser(raw_tool_calls)
|
||||
values["tool_calls"] = tool_calls
|
||||
values["invalid_tool_calls"] = invalid_tool_calls
|
||||
parsed_tool_calls, parsed_invalid_tool_calls = default_tool_parser(
|
||||
raw_tool_calls
|
||||
)
|
||||
values["tool_calls"] = parsed_tool_calls
|
||||
values["invalid_tool_calls"] = parsed_invalid_tool_calls
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Ensure "type" is properly set on all tool call-like dicts.
|
||||
if tool_calls := values.get("tool_calls"):
|
||||
updated: List = []
|
||||
for tc in tool_calls:
|
||||
updated.append(
|
||||
create_tool_call(**{k: v for k, v in tc.items() if k != "type"})
|
||||
)
|
||||
values["tool_calls"] = updated
|
||||
if invalid_tool_calls := values.get("invalid_tool_calls"):
|
||||
updated = []
|
||||
for tc in invalid_tool_calls:
|
||||
updated.append(
|
||||
create_invalid_tool_call(
|
||||
**{k: v for k, v in tc.items() if k != "type"}
|
||||
)
|
||||
)
|
||||
values["invalid_tool_calls"] = updated
|
||||
|
||||
if tool_call_chunks := values.get("tool_call_chunks"):
|
||||
updated = []
|
||||
for tc in tool_call_chunks:
|
||||
updated.append(
|
||||
create_tool_call_chunk(
|
||||
**{k: v for k, v in tc.items() if k != "type"}
|
||||
)
|
||||
)
|
||||
values["tool_call_chunks"] = updated
|
||||
|
||||
return values
|
||||
|
||||
def pretty_repr(self, html: bool = False) -> str:
|
||||
@ -216,7 +254,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
if not values["tool_call_chunks"]:
|
||||
if values["tool_calls"]:
|
||||
values["tool_call_chunks"] = [
|
||||
ToolCallChunk(
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"],
|
||||
args=json.dumps(tc["args"]),
|
||||
id=tc["id"],
|
||||
@ -228,7 +266,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
tool_call_chunks = values.get("tool_call_chunks", [])
|
||||
tool_call_chunks.extend(
|
||||
[
|
||||
ToolCallChunk(
|
||||
create_tool_call_chunk(
|
||||
name=tc["name"], args=tc["args"], id=tc["id"], index=None
|
||||
)
|
||||
for tc in values["invalid_tool_calls"]
|
||||
@ -244,7 +282,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
args_ = parse_partial_json(chunk["args"]) if chunk["args"] != "" else {}
|
||||
if isinstance(args_, dict):
|
||||
tool_calls.append(
|
||||
ToolCall(
|
||||
create_tool_call(
|
||||
name=chunk["name"] or "",
|
||||
args=args_,
|
||||
id=chunk["id"],
|
||||
@ -254,7 +292,7 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
|
||||
raise ValueError("Malformed args.")
|
||||
except Exception:
|
||||
invalid_tool_calls.append(
|
||||
InvalidToolCall(
|
||||
create_invalid_tool_call(
|
||||
name=chunk["name"],
|
||||
args=chunk["args"],
|
||||
id=chunk["id"],
|
||||
@ -297,7 +335,7 @@ def add_ai_message_chunks(
|
||||
left.tool_call_chunks, *(o.tool_call_chunks for o in others)
|
||||
):
|
||||
tool_call_chunks = [
|
||||
ToolCallChunk(
|
||||
create_tool_call_chunk(
|
||||
name=rtc.get("name"),
|
||||
args=rtc.get("args"),
|
||||
index=rtc.get("index"),
|
||||
|
@ -237,25 +237,25 @@ def default_tool_parser(
|
||||
"""Best-effort parsing of tools."""
|
||||
tool_calls = []
|
||||
invalid_tool_calls = []
|
||||
for tool_call in raw_tool_calls:
|
||||
if "function" not in tool_call:
|
||||
for raw_tool_call in raw_tool_calls:
|
||||
if "function" not in raw_tool_call:
|
||||
continue
|
||||
else:
|
||||
function_name = tool_call["function"]["name"]
|
||||
function_name = raw_tool_call["function"]["name"]
|
||||
try:
|
||||
function_args = json.loads(tool_call["function"]["arguments"])
|
||||
parsed = ToolCall(
|
||||
function_args = json.loads(raw_tool_call["function"]["arguments"])
|
||||
parsed = tool_call(
|
||||
name=function_name or "",
|
||||
args=function_args or {},
|
||||
id=tool_call.get("id"),
|
||||
id=raw_tool_call.get("id"),
|
||||
)
|
||||
tool_calls.append(parsed)
|
||||
except json.JSONDecodeError:
|
||||
invalid_tool_calls.append(
|
||||
InvalidToolCall(
|
||||
invalid_tool_call(
|
||||
name=function_name,
|
||||
args=tool_call["function"]["arguments"],
|
||||
id=tool_call.get("id"),
|
||||
args=raw_tool_call["function"]["arguments"],
|
||||
id=raw_tool_call.get("id"),
|
||||
error=None,
|
||||
)
|
||||
)
|
||||
@ -272,7 +272,7 @@ def default_tool_chunk_parser(raw_tool_calls: List[dict]) -> List[ToolCallChunk]
|
||||
else:
|
||||
function_args = tool_call["function"]["arguments"]
|
||||
function_name = tool_call["function"]["name"]
|
||||
parsed = ToolCallChunk(
|
||||
parsed = tool_call_chunk(
|
||||
name=function_name,
|
||||
args=function_args,
|
||||
id=tool_call.get("id"),
|
||||
|
@ -451,12 +451,12 @@ def merge_message_runs(
|
||||
HumanMessage("wait your favorite food", id="bar",),
|
||||
AIMessage(
|
||||
"my favorite colo",
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123")],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": 2}, id="123", type="tool_call")],
|
||||
id="baz",
|
||||
),
|
||||
AIMessage(
|
||||
[{"type": "text", "text": "my favorite dish is lasagna"}],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456")],
|
||||
tool_calls=[ToolCall(name="blah_tool", args={"x": -10}, id="456", type="tool_call")],
|
||||
id="blur",
|
||||
),
|
||||
]
|
||||
@ -474,8 +474,8 @@ def merge_message_runs(
|
||||
{"type": "text", "text": "my favorite dish is lasagna"}
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123"),
|
||||
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456")
|
||||
ToolCall({"name": "blah_tool", "args": {"x": 2}, "id": "123", "type": "tool_call"}),
|
||||
ToolCall({"name": "blah_tool", "args": {"x": -10}, "id": "456", "type": "tool_call"})
|
||||
]
|
||||
id="baz"
|
||||
),
|
||||
|
@ -1,19 +1,16 @@
|
||||
from langchain_core.load import dumpd, load
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
InvalidToolCall,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
|
||||
|
||||
def test_serdes_message() -> None:
|
||||
msg = AIMessage(
|
||||
content=[{"text": "blah", "type": "text"}],
|
||||
tool_calls=[ToolCall(name="foo", args={"bar": 1}, id="baz")],
|
||||
tool_calls=[create_tool_call(name="foo", args={"bar": 1}, id="baz")],
|
||||
invalid_tool_calls=[
|
||||
InvalidToolCall(name="foobad", args="blah", id="booz", error="bad")
|
||||
create_invalid_tool_call(name="foobad", args="blah", id="booz", error="bad")
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
@ -23,9 +20,17 @@ def test_serdes_message() -> None:
|
||||
"kwargs": {
|
||||
"type": "ai",
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"tool_calls": [
|
||||
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
|
||||
],
|
||||
"invalid_tool_calls": [
|
||||
{"name": "foobad", "args": "blah", "id": "booz", "error": "bad"}
|
||||
{
|
||||
"name": "foobad",
|
||||
"args": "blah",
|
||||
"id": "booz",
|
||||
"error": "bad",
|
||||
"type": "invalid_tool_call",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
@ -38,8 +43,13 @@ def test_serdes_message_chunk() -> None:
|
||||
chunk = AIMessageChunk(
|
||||
content=[{"text": "blah", "type": "text"}],
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name="foo", args='{"bar": 1}', id="baz", index=0),
|
||||
ToolCallChunk(name="foobad", args="blah", id="booz", index=1),
|
||||
create_tool_call_chunk(name="foo", args='{"bar": 1}', id="baz", index=0),
|
||||
create_tool_call_chunk(
|
||||
name="foobad",
|
||||
args="blah",
|
||||
id="booz",
|
||||
index=1,
|
||||
),
|
||||
],
|
||||
)
|
||||
expected = {
|
||||
@ -49,18 +59,33 @@ def test_serdes_message_chunk() -> None:
|
||||
"kwargs": {
|
||||
"type": "AIMessageChunk",
|
||||
"content": [{"text": "blah", "type": "text"}],
|
||||
"tool_calls": [{"name": "foo", "args": {"bar": 1}, "id": "baz"}],
|
||||
"tool_calls": [
|
||||
{"name": "foo", "args": {"bar": 1}, "id": "baz", "type": "tool_call"}
|
||||
],
|
||||
"invalid_tool_calls": [
|
||||
{
|
||||
"name": "foobad",
|
||||
"args": "blah",
|
||||
"id": "booz",
|
||||
"error": None,
|
||||
"type": "invalid_tool_call",
|
||||
}
|
||||
],
|
||||
"tool_call_chunks": [
|
||||
{"name": "foo", "args": '{"bar": 1}', "id": "baz", "index": 0},
|
||||
{"name": "foobad", "args": "blah", "id": "booz", "index": 1},
|
||||
{
|
||||
"name": "foo",
|
||||
"args": '{"bar": 1}',
|
||||
"id": "baz",
|
||||
"index": 0,
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
{
|
||||
"name": "foobad",
|
||||
"args": "blah",
|
||||
"id": "booz",
|
||||
"index": 1,
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
@ -35,12 +35,16 @@ def test_merge_message_runs_content() -> None:
|
||||
{"text": "bar", "type": "text"},
|
||||
{"image_url": "...", "type": "image_url"},
|
||||
],
|
||||
tool_calls=[ToolCall(name="foo_tool", args={"x": 1}, id="tool1")],
|
||||
tool_calls=[
|
||||
ToolCall(name="foo_tool", args={"x": 1}, id="tool1", type="tool_call")
|
||||
],
|
||||
id="2",
|
||||
),
|
||||
AIMessage(
|
||||
"baz",
|
||||
tool_calls=[ToolCall(name="foo_tool", args={"x": 5}, id="tool2")],
|
||||
tool_calls=[
|
||||
ToolCall(name="foo_tool", args={"x": 5}, id="tool2", type="tool_call")
|
||||
],
|
||||
id="3",
|
||||
),
|
||||
]
|
||||
@ -54,8 +58,8 @@ def test_merge_message_runs_content() -> None:
|
||||
"baz",
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCall(name="foo_tool", args={"x": 1}, id="tool1"),
|
||||
ToolCall(name="foo_tool", args={"x": 5}, id="tool2"),
|
||||
ToolCall(name="foo_tool", args={"x": 1}, id="tool1", type="tool_call"),
|
||||
ToolCall(name="foo_tool", args={"x": 5}, id="tool2", type="tool_call"),
|
||||
],
|
||||
id="1",
|
||||
),
|
||||
|
@ -15,8 +15,6 @@ from langchain_core.messages import (
|
||||
HumanMessageChunk,
|
||||
RemoveMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
ToolCallChunk,
|
||||
ToolMessage,
|
||||
convert_to_messages,
|
||||
get_buffer_string,
|
||||
@ -25,6 +23,9 @@ from langchain_core.messages import (
|
||||
messages_from_dict,
|
||||
messages_to_dict,
|
||||
)
|
||||
from langchain_core.messages.tool import invalid_tool_call as create_invalid_tool_call
|
||||
from langchain_core.messages.tool import tool_call as create_tool_call
|
||||
from langchain_core.messages.tool import tool_call_chunk as create_tool_call_chunk
|
||||
from langchain_core.utils._merge import merge_lists
|
||||
|
||||
|
||||
@ -77,57 +78,73 @@ def test_message_chunks() -> None:
|
||||
# Test tool calls
|
||||
assert (
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0)
|
||||
create_tool_call_chunk(name="tool1", args="", id="1", index=0)
|
||||
],
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name=None, args='ue}"', id=None, index=0)],
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(
|
||||
name=None, args='{"arg1": "val', id=None, index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(name=None, args='ue}"', id=None, index=0)
|
||||
],
|
||||
)
|
||||
) == AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name="tool1", args='{"arg1": "value}"', id="1", index=0)
|
||||
create_tool_call_chunk(
|
||||
name="tool1", args='{"arg1": "value}"', id="1", index=0
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
assert (
|
||||
AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(name="tool1", args="", id="1", index=0)
|
||||
],
|
||||
)
|
||||
+ AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)],
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(name="tool1", args="a", id=None, index=1)
|
||||
],
|
||||
)
|
||||
# Don't merge if `index` field does not match.
|
||||
) == AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name="tool1", args="", id="1", index=0),
|
||||
ToolCallChunk(name="tool1", args="a", id=None, index=1),
|
||||
create_tool_call_chunk(name="tool1", args="", id="1", index=0),
|
||||
create_tool_call_chunk(name="tool1", args="a", id=None, index=1),
|
||||
],
|
||||
)
|
||||
|
||||
ai_msg_chunk = AIMessageChunk(content="")
|
||||
tool_calls_msg_chunk = AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="a", id=None, index=1)],
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(name="tool1", args="a", id=None, index=1)
|
||||
],
|
||||
)
|
||||
assert ai_msg_chunk + tool_calls_msg_chunk == tool_calls_msg_chunk
|
||||
assert tool_calls_msg_chunk + ai_msg_chunk == tool_calls_msg_chunk
|
||||
|
||||
ai_msg_chunk = AIMessageChunk(
|
||||
content="",
|
||||
tool_call_chunks=[ToolCallChunk(name="tool1", args="", id="1", index=0)],
|
||||
tool_call_chunks=[
|
||||
create_tool_call_chunk(name="tool1", args="", id="1", index=0)
|
||||
],
|
||||
)
|
||||
assert ai_msg_chunk.tool_calls == [ToolCall(name="tool1", args={}, id="1")]
|
||||
assert ai_msg_chunk.tool_calls == [create_tool_call(name="tool1", args={}, id="1")]
|
||||
|
||||
# Test token usage
|
||||
left = AIMessageChunk(
|
||||
@ -347,11 +364,11 @@ def test_multiple_msg() -> None:
|
||||
msgs = [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(name="a", args={"b": 1}, id=None)],
|
||||
tool_calls=[create_tool_call(name="a", args={"b": 1}, id=None)],
|
||||
),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(name="c", args={"c": 2}, id=None)],
|
||||
tool_calls=[create_tool_call(name="c", args={"c": 2}, id=None)],
|
||||
),
|
||||
]
|
||||
assert messages_from_dict(messages_to_dict(msgs)) == msgs
|
||||
@ -389,21 +406,25 @@ def test_message_chunk_to_message() -> None:
|
||||
chunk = AIMessageChunk(
|
||||
content="I am",
|
||||
tool_call_chunks=[
|
||||
ToolCallChunk(name="tool1", args='{"a": 1}', id="1", index=0),
|
||||
ToolCallChunk(name="tool2", args='{"b": ', id="2", index=0),
|
||||
ToolCallChunk(name="tool3", args=None, id="3", index=0),
|
||||
ToolCallChunk(name="tool4", args="abc", id="4", index=0),
|
||||
create_tool_call_chunk(name="tool1", args='{"a": 1}', id="1", index=0),
|
||||
create_tool_call_chunk(name="tool2", args='{"b": ', id="2", index=0),
|
||||
create_tool_call_chunk(name="tool3", args=None, id="3", index=0),
|
||||
create_tool_call_chunk(name="tool4", args="abc", id="4", index=0),
|
||||
],
|
||||
)
|
||||
expected = AIMessage(
|
||||
content="I am",
|
||||
tool_calls=[
|
||||
{"name": "tool1", "args": {"a": 1}, "id": "1"},
|
||||
{"name": "tool2", "args": {}, "id": "2"},
|
||||
create_tool_call(**{"name": "tool1", "args": {"a": 1}, "id": "1"}), # type: ignore[arg-type]
|
||||
create_tool_call(**{"name": "tool2", "args": {}, "id": "2"}), # type: ignore[arg-type]
|
||||
],
|
||||
invalid_tool_calls=[
|
||||
{"name": "tool3", "args": None, "id": "3", "error": None},
|
||||
{"name": "tool4", "args": "abc", "id": "4", "error": None},
|
||||
create_invalid_tool_call(
|
||||
**{"name": "tool3", "args": None, "id": "3", "error": None}
|
||||
),
|
||||
create_invalid_tool_call(
|
||||
**{"name": "tool4", "args": "abc", "id": "4", "error": None}
|
||||
),
|
||||
],
|
||||
)
|
||||
assert message_chunk_to_message(chunk) == expected
|
||||
@ -632,6 +653,36 @@ def test_tool_calls_merge() -> None:
|
||||
},
|
||||
]
|
||||
},
|
||||
tool_call_chunks=[
|
||||
{
|
||||
"name": "person",
|
||||
"args": '{"name": "jane", "age": 2}',
|
||||
"id": "call_CwGAsESnXehQEjiAIWzinlva",
|
||||
"index": 0,
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
{
|
||||
"name": "person",
|
||||
"args": '{"name": "bob", "age": 3}',
|
||||
"id": "call_zXSIylHvc5x3JUAPcHZR5GZI",
|
||||
"index": 1,
|
||||
"type": "tool_call_chunk",
|
||||
},
|
||||
],
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "person",
|
||||
"args": {"name": "jane", "age": 2},
|
||||
"id": "call_CwGAsESnXehQEjiAIWzinlva",
|
||||
"type": "tool_call",
|
||||
},
|
||||
{
|
||||
"name": "person",
|
||||
"args": {"name": "bob", "age": 3},
|
||||
"id": "call_zXSIylHvc5x3JUAPcHZR5GZI",
|
||||
"type": "tool_call",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -654,7 +705,12 @@ def test_convert_to_messages() -> None:
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"name": "greet", "args": {"name": "Jane"}, "id": "tool_id"}
|
||||
{
|
||||
"name": "greet",
|
||||
"args": {"name": "Jane"},
|
||||
"id": "tool_id",
|
||||
"type": "tool_call",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "tool_id", "content": "Hi!"},
|
||||
@ -682,7 +738,9 @@ def test_convert_to_messages() -> None:
|
||||
FunctionMessage(name="greet", content="Hi!"),
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[ToolCall(name="greet", args={"name": "Jane"}, id="tool_id")],
|
||||
tool_calls=[
|
||||
create_tool_call(name="greet", args={"name": "Jane"}, id="tool_id")
|
||||
],
|
||||
),
|
||||
ToolMessage(tool_call_id="tool_id", content="Hi!"),
|
||||
ToolMessage(tool_call_id="tool_id2", content="Bye!", artifact={"foo": 123}),
|
||||
@ -755,32 +813,60 @@ def test_message_name_chat(MessageClass: Type) -> None:
|
||||
|
||||
|
||||
def test_merge_tool_calls() -> None:
|
||||
tool_call_1 = ToolCallChunk(name="tool1", args="", id="1", index=0)
|
||||
tool_call_2 = ToolCallChunk(name=None, args='{"arg1": "val', id=None, index=0)
|
||||
tool_call_3 = ToolCallChunk(name=None, args='ue}"', id=None, index=0)
|
||||
tool_call_1 = create_tool_call_chunk(name="tool1", args="", id="1", index=0)
|
||||
tool_call_2 = create_tool_call_chunk(
|
||||
name=None, args='{"arg1": "val', id=None, index=0
|
||||
)
|
||||
tool_call_3 = create_tool_call_chunk(name=None, args='ue}"', id=None, index=0)
|
||||
merged = merge_lists([tool_call_1], [tool_call_2])
|
||||
assert merged is not None
|
||||
assert merged == [{"name": "tool1", "args": '{"arg1": "val', "id": "1", "index": 0}]
|
||||
assert merged == [
|
||||
{
|
||||
"name": "tool1",
|
||||
"args": '{"arg1": "val',
|
||||
"id": "1",
|
||||
"index": 0,
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
]
|
||||
merged = merge_lists(merged, [tool_call_3])
|
||||
assert merged is not None
|
||||
assert merged == [
|
||||
{"name": "tool1", "args": '{"arg1": "value}"', "id": "1", "index": 0}
|
||||
{
|
||||
"name": "tool1",
|
||||
"args": '{"arg1": "value}"',
|
||||
"id": "1",
|
||||
"index": 0,
|
||||
"type": "tool_call_chunk",
|
||||
}
|
||||
]
|
||||
|
||||
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=None)
|
||||
right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id="1", index=None)
|
||||
left = create_tool_call_chunk(
|
||||
name="tool1", args='{"arg1": "value1"}', id="1", index=None
|
||||
)
|
||||
right = create_tool_call_chunk(
|
||||
name="tool2", args='{"arg2": "value2"}', id="1", index=None
|
||||
)
|
||||
merged = merge_lists([left], [right])
|
||||
assert merged is not None
|
||||
assert len(merged) == 2
|
||||
|
||||
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id=None, index=None)
|
||||
right = ToolCallChunk(name="tool1", args='{"arg2": "value2"}', id=None, index=None)
|
||||
left = create_tool_call_chunk(
|
||||
name="tool1", args='{"arg1": "value1"}', id=None, index=None
|
||||
)
|
||||
right = create_tool_call_chunk(
|
||||
name="tool1", args='{"arg2": "value2"}', id=None, index=None
|
||||
)
|
||||
merged = merge_lists([left], [right])
|
||||
assert merged is not None
|
||||
assert len(merged) == 2
|
||||
|
||||
left = ToolCallChunk(name="tool1", args='{"arg1": "value1"}', id="1", index=0)
|
||||
right = ToolCallChunk(name="tool2", args='{"arg2": "value2"}', id=None, index=1)
|
||||
left = create_tool_call_chunk(
|
||||
name="tool1", args='{"arg1": "value1"}', id="1", index=0
|
||||
)
|
||||
right = create_tool_call_chunk(
|
||||
name="tool2", args='{"arg2": "value2"}', id=None, index=1
|
||||
)
|
||||
merged = merge_lists([left], [right])
|
||||
assert merged is not None
|
||||
assert len(merged) == 2
|
||||
|
4
libs/partners/together/poetry.lock
generated
4
libs/partners/together/poetry.lock
generated
@ -604,7 +604,7 @@ url = "../../core"
|
||||
|
||||
[[package]]
|
||||
name = "langchain-openai"
|
||||
version = "0.1.15"
|
||||
version = "0.1.16"
|
||||
description = "An integration package connecting OpenAI and LangChain"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
@ -1798,4 +1798,4 @@ multidict = ">=4.0"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
content-hash = "32ec0f5b1afd7492c096028403f1f2f94a227cb0d922530ca743e1bd65db3f9f"
|
||||
content-hash = "8e255f5a0e6ecf23a3d04d0eeee9918d411339b5960c24c888521fbd1f6bf531"
|
||||
|
@ -20,7 +20,7 @@ disallow_untyped_defs = "True"
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = ">=0.2.17,<0.3"
|
||||
langchain-openai = "^0.1.8"
|
||||
langchain-openai = "^0.1.16"
|
||||
requests = "^2"
|
||||
aiohttp = "^3.9.1"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user