fix(core): fix parse_result in case of self.first_tool_only with multiple keys matching for JsonOutputKeyToolsParser (#32106)

* **Description:** Updated `parse_result` logic to handle cases where
`self.first_tool_only` is `True` and multiple matching keys share the
same function name. Instead of returning the first match prematurely,
the method now prioritizes filtering results by the specified key to
ensure correct selection.
* **Issue:** #32100

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Mohammad Mohtashim 2025-07-21 21:50:22 +05:00 committed by GitHub
parent ddaba21e83
commit 095f4a7c28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 336 additions and 33 deletions

View File

@ -234,12 +234,39 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Returns:
The parsed tool calls.
"""
parsed_result = super().parse_result(result, partial=partial)
generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
# For backwards compatibility
for tc in parsed_tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
parsed_result = list(
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
)
single_result = (
parsed_result
if parsed_result and parsed_result["type"] == self.key_name
parsed_result[0]
if parsed_result and parsed_result[0]["type"] == self.key_name
else None
)
if self.return_id:
@ -247,10 +274,13 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
if single_result:
return single_result["args"]
return None
parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
if not self.return_id:
parsed_result = [res["args"] for res in parsed_result]
return parsed_result
return (
[res for res in parsed_tool_calls if res["type"] == self.key_name]
if self.return_id
else [
res["args"] for res in parsed_tool_calls if res["type"] == self.key_name
]
)
# Common cause of ValidationError is truncated output due to max_tokens.

View File

@ -1,24 +1,23 @@
"""Output classes.
**Output** classes are used to represent the output of a language model call
and the output of a chat.
Used to represent the output of a language model call and the output of a chat.
The top container for information is the `LLMResult` object. `LLMResult` is used by
both chat models and LLMs. This object contains the output of the language
model and any additional information that the model provider wants to return.
The top container for information is the `LLMResult` object. `LLMResult` is used by both
chat models and LLMs. This object contains the output of the language model and any
additional information that the model provider wants to return.
When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):
- Chat models will return `AIMessage` objects.
- LLMs will return regular text strings.
In addition, users can access the raw output of either LLMs or chat models via
callbacks. The on_chat_model_end and on_llm_end callbacks will return an
callbacks. The ``on_chat_model_end`` and ``on_llm_end`` callbacks will return an
LLMResult object containing the generated outputs and any additional information
returned by the model provider.
In general, if information is already available
in the AIMessage object, it is recommended to access it from there rather than
from the `LLMResult` object.
In general, if information is already available in the AIMessage object, it is
recommended to access it from there rather than from the `LLMResult` object.
"""
from typing import TYPE_CHECKING

View File

@ -27,7 +27,11 @@ class ChatGeneration(Generation):
"""
text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
"""The text contents of the output message.
.. warning::
SHOULD NOT BE SET DIRECTLY!
"""
message: BaseMessage
"""The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional

View File

@ -11,7 +11,8 @@ from langchain_core.utils._merge import merge_dicts
class Generation(Serializable):
"""A single text generation output.
Generation represents the response from an "old-fashioned" LLM that
Generation represents the response from an
`"old-fashioned" LLM <https://python.langchain.com/docs/concepts/text_llms/>__` that
generates regular text (not chat messages).
This model is used internally by chat model and will eventually

View File

@ -15,9 +15,9 @@ from langchain_core.outputs.run_info import RunInfo
class LLMResult(BaseModel):
"""A container for results of an LLM call.
Both chat models and LLMs generate an LLMResult object. This object contains
the generated outputs and any additional information that the model provider
wants to return.
Both chat models and LLMs generate an LLMResult object. This object contains the
generated outputs and any additional information that the model provider wants to
return.
"""
generations: list[
@ -25,17 +25,16 @@ class LLMResult(BaseModel):
]
"""Generated outputs.
The first dimension of the list represents completions for different input
prompts.
The first dimension of the list represents completions for different input prompts.
The second dimension of the list represents different candidate generations
for a given prompt.
The second dimension of the list represents different candidate generations for a
given prompt.
When returned from an LLM the type is list[list[Generation]].
When returned from a chat model the type is list[list[ChatGeneration]].
- When returned from **an LLM**, the type is ``list[list[Generation]]``.
- When returned from a **chat model**, the type is ``list[list[ChatGeneration]]``.
ChatGeneration is a subclass of Generation that has a field for a structured
chat message.
ChatGeneration is a subclass of Generation that has a field for a structured chat
message.
"""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output.
@ -43,9 +42,8 @@ class LLMResult(BaseModel):
This dictionary is a free-form dictionary that can contain any information that the
provider wants to return. It is not standardized and is provider-specific.
Users should generally avoid relying on this field and instead rely on
accessing relevant information from standardized fields present in
AIMessage.
Users should generally avoid relying on this field and instead rely on accessing
relevant information from standardized fields present in AIMessage.
"""
run: Optional[list[RunInfo]] = None
"""List of metadata info for model call for each input."""

View File

@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
    *, use_tool_calls: bool
) -> None:
    # Regression scenario from the original bug report: with first_tool_only=True
    # the parser must return the call matching key_name, not just the first call.
    raw_tool_calls = [
        {
            "id": "call_other",
            "function": {"name": "other", "arguments": '{"b":2}'},
            "type": "function",
        },
        {
            "id": "call_func",
            "function": {"name": "func", "arguments": '{"a":1}'},
            "type": "function",
        },
    ]
    if use_tool_calls:
        message = AIMessage(
            content="",
            tool_calls=[
                {"id": "call_other", "name": "other", "args": {"b": 2}},
                {"id": "call_func", "name": "func", "args": {"a": 1}},
            ],
        )
    else:
        message = AIMessage(
            content="",
            additional_kwargs={"tool_calls": raw_tool_calls},
        )
    generations = [ChatGeneration(message=message)]

    # With return_id=True the full "func" tool-call dict comes back, id included.
    parser_with_id = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    parsed = parser_with_id.parse_result(generations)  # type: ignore[arg-type]
    assert parsed is not None
    assert parsed["type"] == "func"
    assert parsed["args"] == {"a": 1}
    assert "id" in parsed

    # With return_id=False only the arguments mapping is returned.
    parser_args_only = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=False
    )
    parsed_args = parser_args_only.parse_result(generations)  # type: ignore[arg-type]
    assert parsed_args == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
    *, use_tool_calls: bool
) -> None:
    # When no tool call matches key_name, first_tool_only must yield None
    # regardless of the return_id setting.
    raw_tool_calls = [
        {
            "id": "call_other",
            "function": {"name": "other", "arguments": '{"b":2}'},
            "type": "function",
        },
        {
            "id": "call_another",
            "function": {"name": "another", "arguments": '{"c":3}'},
            "type": "function",
        },
    ]
    if use_tool_calls:
        message = AIMessage(
            content="",
            tool_calls=[
                {"id": "call_other", "name": "other", "args": {"b": 2}},
                {"id": "call_another", "name": "another", "args": {"c": 3}},
            ],
        )
    else:
        message = AIMessage(
            content="",
            additional_kwargs={"tool_calls": raw_tool_calls},
        )
    generations = [ChatGeneration(message=message)]

    parser_with_id = JsonOutputKeyToolsParser(
        key_name="nonexistent", first_tool_only=True, return_id=True
    )
    assert parser_with_id.parse_result(generations) is None  # type: ignore[arg-type]

    parser_args_only = JsonOutputKeyToolsParser(
        key_name="nonexistent", first_tool_only=True, return_id=False
    )
    assert parser_args_only.parse_result(generations) is None  # type: ignore[arg-type]
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
    *, use_tool_calls: bool
) -> None:
    # Two calls named "func" sandwich an unrelated "other" call; the parser
    # must pick the first match (first_tool_only) or every match otherwise.
    raw_tool_calls = [
        {
            "id": "call_func1",
            "function": {"name": "func", "arguments": '{"a":1}'},
            "type": "function",
        },
        {
            "id": "call_other",
            "function": {"name": "other", "arguments": '{"b":2}'},
            "type": "function",
        },
        {
            "id": "call_func2",
            "function": {"name": "func", "arguments": '{"a":3}'},
            "type": "function",
        },
    ]
    if use_tool_calls:
        message = AIMessage(
            content="",
            tool_calls=[
                {"id": "call_func1", "name": "func", "args": {"a": 1}},
                {"id": "call_other", "name": "other", "args": {"b": 2}},
                {"id": "call_func2", "name": "func", "args": {"a": 3}},
            ],
        )
    else:
        message = AIMessage(
            content="",
            additional_kwargs={"tool_calls": raw_tool_calls},
        )
    generations = [ChatGeneration(message=message)]

    # first_tool_only=True -> the earliest "func" call wins.
    first_parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    first_match = first_parser.parse_result(generations)  # type: ignore[arg-type]
    assert first_match is not None
    assert first_match["type"] == "func"
    assert first_match["args"] == {"a": 1}  # First matching tool call

    # first_tool_only=False -> both "func" calls are returned, in order.
    all_parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    all_matches = all_parser.parse_result(generations)  # type: ignore[arg-type]
    assert len(all_matches) == 2
    assert all_matches[0]["args"] == {"a": 1}
    assert all_matches[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
    # A message carrying zero tool calls must parse to None (first_tool_only)
    # or an empty list (all matches).
    if use_tool_calls:
        message = AIMessage(content="", tool_calls=[])
    else:
        message = AIMessage(content="", additional_kwargs={"tool_calls": []})
    generations = [ChatGeneration(message=message)]

    first_parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    assert first_parser.parse_result(generations) is None  # type: ignore[arg-type]

    all_parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    assert all_parser.parse_result(generations) == []  # type: ignore[arg-type]
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
    *, use_tool_calls: bool
) -> None:
    """Exercise every first_tool_only x return_id combination of the parser."""
    raw_tool_calls = [
        {
            "id": "call_other",
            "function": {"name": "other", "arguments": '{"b":2}'},
            "type": "function",
        },
        {
            "id": "call_func1",
            "function": {"name": "func", "arguments": '{"a":1}'},
            "type": "function",
        },
        {
            "id": "call_func2",
            "function": {"name": "func", "arguments": '{"a":3}'},
            "type": "function",
        },
    ]
    if use_tool_calls:
        message = AIMessage(
            content="",
            tool_calls=[
                {"id": "call_other", "name": "other", "args": {"b": 2}},
                {"id": "call_func1", "name": "func", "args": {"a": 1}},
                {"id": "call_func2", "name": "func", "args": {"a": 3}},
            ],
        )
    else:
        message = AIMessage(
            content="",
            additional_kwargs={"tool_calls": raw_tool_calls},
        )
    generations: list[ChatGeneration] = [ChatGeneration(message=message)]

    # first_tool_only=True, return_id=True -> full dict of first "func" match.
    combo1 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    out1 = combo1.parse_result(generations)  # type: ignore[arg-type]
    assert out1["type"] == "func"
    assert out1["args"] == {"a": 1}
    assert "id" in out1

    # first_tool_only=True, return_id=False -> just the args of the first match.
    combo2 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=False
    )
    out2 = combo2.parse_result(generations)  # type: ignore[arg-type]
    assert out2 == {"a": 1}

    # first_tool_only=False, return_id=True -> every match, ids retained.
    combo3 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    out3 = combo3.parse_result(generations)  # type: ignore[arg-type]
    assert len(out3) == 2
    assert all("id" in item for item in out3)
    assert out3[0]["args"] == {"a": 1}
    assert out3[1]["args"] == {"a": 3}

    # first_tool_only=False, return_id=False -> list of args dicts only.
    combo4 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=False
    )
    out4 = combo4.parse_result(generations)  # type: ignore[arg-type]
    assert out4 == [{"a": 1}, {"a": 3}]
class Person(BaseModel):
    """Simple pydantic model.

    NOTE(review): presumably used as a structured-output schema fixture by
    later tests — confirm against the rest of the file (the class may
    continue beyond this excerpt).
    """

    age: int  # integer age value
    hair_color: str  # free-form color string