mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
fix(core): fix parse_result
in case of self.first_tool_only with multiple keys matching for JsonOutputKeyToolsParser (#32106)
* **Description:** Updated `parse_result` logic to handle cases where `self.first_tool_only` is `True` and multiple matching keys share the same function name. Instead of returning the first match prematurely, the method now prioritizes filtering results by the specified key to ensure correct selection. * **Issue:** #32100 --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
ddaba21e83
commit
095f4a7c28
@@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only(
|
||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_tools_first_only(
|
||||
*, use_tool_calls: bool
|
||||
) -> None:
|
||||
# Test case from the original bug report
|
||||
def create_message() -> AIMessage:
|
||||
tool_calls_data = [
|
||||
{
|
||||
"id": "call_other",
|
||||
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_func",
|
||||
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
|
||||
if use_tool_calls:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||
{"id": "call_func", "name": "func", "args": {"a": 1}},
|
||||
],
|
||||
)
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls_data},
|
||||
)
|
||||
|
||||
result = [ChatGeneration(message=create_message())]
|
||||
|
||||
# Test with return_id=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return the func tool call, not None
|
||||
assert output is not None
|
||||
assert output["type"] == "func"
|
||||
assert output["args"] == {"a": 1}
|
||||
assert "id" in output
|
||||
|
||||
# Test with return_id=False
|
||||
parser_no_id = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=False
|
||||
)
|
||||
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return just the args
|
||||
assert output_no_id == {"a": 1}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_tools_no_match(
|
||||
*, use_tool_calls: bool
|
||||
) -> None:
|
||||
def create_message() -> AIMessage:
|
||||
tool_calls_data = [
|
||||
{
|
||||
"id": "call_other",
|
||||
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_another",
|
||||
"function": {"name": "another", "arguments": '{"c":3}'},
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
|
||||
if use_tool_calls:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||
{"id": "call_another", "name": "another", "args": {"c": 3}},
|
||||
],
|
||||
)
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls_data},
|
||||
)
|
||||
|
||||
result = [ChatGeneration(message=create_message())]
|
||||
|
||||
# Test with return_id=True, first_tool_only=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="nonexistent", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return None when no matches
|
||||
assert output is None
|
||||
|
||||
# Test with return_id=False, first_tool_only=True
|
||||
parser_no_id = JsonOutputKeyToolsParser(
|
||||
key_name="nonexistent", first_tool_only=True, return_id=False
|
||||
)
|
||||
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return None when no matches
|
||||
assert output_no_id is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_multiple_matching_tools(
|
||||
*, use_tool_calls: bool
|
||||
) -> None:
|
||||
def create_message() -> AIMessage:
|
||||
tool_calls_data = [
|
||||
{
|
||||
"id": "call_func1",
|
||||
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_other",
|
||||
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_func2",
|
||||
"function": {"name": "func", "arguments": '{"a":3}'},
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
|
||||
if use_tool_calls:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||
{"id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||
],
|
||||
)
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls_data},
|
||||
)
|
||||
|
||||
result = [ChatGeneration(message=create_message())]
|
||||
|
||||
# Test with first_tool_only=True - should return first matching
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
assert output is not None
|
||||
assert output["type"] == "func"
|
||||
assert output["args"] == {"a": 1} # First matching tool call
|
||||
|
||||
# Test with first_tool_only=False - should return all matching
|
||||
parser_all = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
assert len(output_all) == 2
|
||||
assert output_all[0]["args"] == {"a": 1}
|
||||
assert output_all[1]["args"] == {"a": 3}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
|
||||
def create_message() -> AIMessage:
|
||||
if use_tool_calls:
|
||||
return AIMessage(content="", tool_calls=[])
|
||||
return AIMessage(content="", additional_kwargs={"tool_calls": []})
|
||||
|
||||
result = [ChatGeneration(message=create_message())]
|
||||
|
||||
# Test with first_tool_only=True
|
||||
parser = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return None for empty results
|
||||
assert output is None
|
||||
|
||||
# Test with first_tool_only=False
|
||||
parser_all = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
|
||||
|
||||
# Should return empty list for empty results
|
||||
assert output_all == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||
def test_json_output_key_tools_parser_parameter_combinations(
|
||||
*, use_tool_calls: bool
|
||||
) -> None:
|
||||
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
|
||||
|
||||
def create_message() -> AIMessage:
|
||||
tool_calls_data = [
|
||||
{
|
||||
"id": "call_other",
|
||||
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_func1",
|
||||
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||
"type": "function",
|
||||
},
|
||||
{
|
||||
"id": "call_func2",
|
||||
"function": {"name": "func", "arguments": '{"a":3}'},
|
||||
"type": "function",
|
||||
},
|
||||
]
|
||||
|
||||
if use_tool_calls:
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||
{"id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||
{"id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||
],
|
||||
)
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"tool_calls": tool_calls_data},
|
||||
)
|
||||
|
||||
result: list[ChatGeneration] = [ChatGeneration(message=create_message())]
|
||||
|
||||
# Test: first_tool_only=True, return_id=True
|
||||
parser1 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=True
|
||||
)
|
||||
output1 = parser1.parse_result(result) # type: ignore[arg-type]
|
||||
assert output1["type"] == "func"
|
||||
assert output1["args"] == {"a": 1}
|
||||
assert "id" in output1
|
||||
|
||||
# Test: first_tool_only=True, return_id=False
|
||||
parser2 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=True, return_id=False
|
||||
)
|
||||
output2 = parser2.parse_result(result) # type: ignore[arg-type]
|
||||
assert output2 == {"a": 1}
|
||||
|
||||
# Test: first_tool_only=False, return_id=True
|
||||
parser3 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=True
|
||||
)
|
||||
output3 = parser3.parse_result(result) # type: ignore[arg-type]
|
||||
assert len(output3) == 2
|
||||
assert all("id" in item for item in output3)
|
||||
assert output3[0]["args"] == {"a": 1}
|
||||
assert output3[1]["args"] == {"a": 3}
|
||||
|
||||
# Test: first_tool_only=False, return_id=False
|
||||
parser4 = JsonOutputKeyToolsParser(
|
||||
key_name="func", first_tool_only=False, return_id=False
|
||||
)
|
||||
output4 = parser4.parse_result(result) # type: ignore[arg-type]
|
||||
assert output4 == [{"a": 1}, {"a": 3}]
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
age: int
|
||||
hair_color: str
|
||||
|
Reference in New Issue
Block a user