fix(core): fix parse_result in case of self.first_tool_only with multiple keys matching for JsonOutputKeyToolsParser (#32106)

* **Description:** Updated `parse_result` logic to handle cases where
`self.first_tool_only` is `True` and multiple matching keys share the
same function name. Instead of returning the first match prematurely,
the method now prioritizes filtering results by the specified key to
ensure correct selection.
* **Issue:** #32100

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Mohammad Mohtashim 2025-07-21 21:50:22 +05:00 committed by GitHub
parent ddaba21e83
commit 095f4a7c28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 336 additions and 33 deletions

View File

@ -234,12 +234,39 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
Returns: Returns:
The parsed tool calls. The parsed tool calls.
""" """
parsed_result = super().parse_result(result, partial=partial) generation = result[0]
if not isinstance(generation, ChatGeneration):
msg = "This output parser can only be used with a chat generation."
raise OutputParserException(msg)
message = generation.message
if isinstance(message, AIMessage) and message.tool_calls:
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
for tool_call in parsed_tool_calls:
if not self.return_id:
_ = tool_call.pop("id")
else:
try:
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
except KeyError:
if self.first_tool_only: if self.first_tool_only:
return None
return []
parsed_tool_calls = parse_tool_calls(
raw_tool_calls,
partial=partial,
strict=self.strict,
return_id=self.return_id,
)
# For backwards compatibility
for tc in parsed_tool_calls:
tc["type"] = tc.pop("name")
if self.first_tool_only:
parsed_result = list(
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
)
single_result = ( single_result = (
parsed_result parsed_result[0]
if parsed_result and parsed_result["type"] == self.key_name if parsed_result and parsed_result[0]["type"] == self.key_name
else None else None
) )
if self.return_id: if self.return_id:
@ -247,10 +274,13 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
if single_result: if single_result:
return single_result["args"] return single_result["args"]
return None return None
parsed_result = [res for res in parsed_result if res["type"] == self.key_name] return (
if not self.return_id: [res for res in parsed_tool_calls if res["type"] == self.key_name]
parsed_result = [res["args"] for res in parsed_result] if self.return_id
return parsed_result else [
res["args"] for res in parsed_tool_calls if res["type"] == self.key_name
]
)
# Common cause of ValidationError is truncated output due to max_tokens. # Common cause of ValidationError is truncated output due to max_tokens.

View File

@ -1,24 +1,23 @@
"""Output classes. """Output classes.
**Output** classes are used to represent the output of a language model call Used to represent the output of a language model call and the output of a chat.
and the output of a chat.
The top container for information is the `LLMResult` object. `LLMResult` is used by The top container for information is the `LLMResult` object. `LLMResult` is used by both
both chat models and LLMs. This object contains the output of the language chat models and LLMs. This object contains the output of the language model and any
model and any additional information that the model provider wants to return. additional information that the model provider wants to return.
When invoking models via the standard runnable methods (e.g. invoke, batch, etc.): When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):
- Chat models will return `AIMessage` objects. - Chat models will return `AIMessage` objects.
- LLMs will return regular text strings. - LLMs will return regular text strings.
In addition, users can access the raw output of either LLMs or chat models via In addition, users can access the raw output of either LLMs or chat models via
callbacks. The on_chat_model_end and on_llm_end callbacks will return an callbacks. The ``on_chat_model_end`` and ``on_llm_end`` callbacks will return an
LLMResult object containing the generated outputs and any additional information LLMResult object containing the generated outputs and any additional information
returned by the model provider. returned by the model provider.
In general, if information is already available In general, if information is already available in the AIMessage object, it is
in the AIMessage object, it is recommended to access it from there rather than recommended to access it from there rather than from the `LLMResult` object.
from the `LLMResult` object.
""" """
from typing import TYPE_CHECKING from typing import TYPE_CHECKING

View File

@ -27,7 +27,11 @@ class ChatGeneration(Generation):
""" """
text: str = "" text: str = ""
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message.""" """The text contents of the output message.
.. warning::
SHOULD NOT BE SET DIRECTLY!
"""
message: BaseMessage message: BaseMessage
"""The message output by the chat model.""" """The message output by the chat model."""
# Override type to be ChatGeneration, ignore mypy error as this is intentional # Override type to be ChatGeneration, ignore mypy error as this is intentional

View File

@ -11,7 +11,8 @@ from langchain_core.utils._merge import merge_dicts
class Generation(Serializable): class Generation(Serializable):
"""A single text generation output. """A single text generation output.
Generation represents the response from an "old-fashioned" LLM that Generation represents the response from an
`"old-fashioned" LLM <https://python.langchain.com/docs/concepts/text_llms/>__` that
generates regular text (not chat messages). generates regular text (not chat messages).
This model is used internally by chat model and will eventually This model is used internally by chat model and will eventually

View File

@ -15,9 +15,9 @@ from langchain_core.outputs.run_info import RunInfo
class LLMResult(BaseModel): class LLMResult(BaseModel):
"""A container for results of an LLM call. """A container for results of an LLM call.
Both chat models and LLMs generate an LLMResult object. This object contains Both chat models and LLMs generate an LLMResult object. This object contains the
the generated outputs and any additional information that the model provider generated outputs and any additional information that the model provider wants to
wants to return. return.
""" """
generations: list[ generations: list[
@ -25,17 +25,16 @@ class LLMResult(BaseModel):
] ]
"""Generated outputs. """Generated outputs.
The first dimension of the list represents completions for different input The first dimension of the list represents completions for different input prompts.
prompts.
The second dimension of the list represents different candidate generations The second dimension of the list represents different candidate generations for a
for a given prompt. given prompt.
When returned from an LLM the type is list[list[Generation]]. - When returned from **an LLM**, the type is ``list[list[Generation]]``.
When returned from a chat model the type is list[list[ChatGeneration]]. - When returned from a **chat model**, the type is ``list[list[ChatGeneration]]``.
ChatGeneration is a subclass of Generation that has a field for a structured ChatGeneration is a subclass of Generation that has a field for a structured chat
chat message. message.
""" """
llm_output: Optional[dict] = None llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output. """For arbitrary LLM provider specific output.
@ -43,9 +42,8 @@ class LLMResult(BaseModel):
This dictionary is a free-form dictionary that can contain any information that the This dictionary is a free-form dictionary that can contain any information that the
provider wants to return. It is not standardized and is provider-specific. provider wants to return. It is not standardized and is provider-specific.
Users should generally avoid relying on this field and instead rely on Users should generally avoid relying on this field and instead rely on accessing
accessing relevant information from standardized fields present in relevant information from standardized fields present in AIMessage.
AIMessage.
""" """
run: Optional[list[RunInfo]] = None run: Optional[list[RunInfo]] = None
"""List of metadata info for model call for each input.""" """List of metadata info for model call for each input."""

View File

@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only(
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
*, use_tool_calls: bool
) -> None:
# Test case from the original bug report
def create_message() -> AIMessage:
tool_calls_data = [
{
"id": "call_other",
"function": {"name": "other", "arguments": '{"b":2}'},
"type": "function",
},
{
"id": "call_func",
"function": {"name": "func", "arguments": '{"a":1}'},
"type": "function",
},
]
if use_tool_calls:
return AIMessage(
content="",
tool_calls=[
{"id": "call_other", "name": "other", "args": {"b": 2}},
{"id": "call_func", "name": "func", "args": {"a": 1}},
],
)
return AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls_data},
)
result = [ChatGeneration(message=create_message())]
# Test with return_id=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(result) # type: ignore[arg-type]
# Should return the func tool call, not None
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1}
assert "id" in output
# Test with return_id=False
parser_no_id = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
# Should return just the args
assert output_no_id == {"a": 1}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
*, use_tool_calls: bool
) -> None:
def create_message() -> AIMessage:
tool_calls_data = [
{
"id": "call_other",
"function": {"name": "other", "arguments": '{"b":2}'},
"type": "function",
},
{
"id": "call_another",
"function": {"name": "another", "arguments": '{"c":3}'},
"type": "function",
},
]
if use_tool_calls:
return AIMessage(
content="",
tool_calls=[
{"id": "call_other", "name": "other", "args": {"b": 2}},
{"id": "call_another", "name": "another", "args": {"c": 3}},
],
)
return AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls_data},
)
result = [ChatGeneration(message=create_message())]
# Test with return_id=True, first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=True
)
output = parser.parse_result(result) # type: ignore[arg-type]
# Should return None when no matches
assert output is None
# Test with return_id=False, first_tool_only=True
parser_no_id = JsonOutputKeyToolsParser(
key_name="nonexistent", first_tool_only=True, return_id=False
)
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
# Should return None when no matches
assert output_no_id is None
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
*, use_tool_calls: bool
) -> None:
def create_message() -> AIMessage:
tool_calls_data = [
{
"id": "call_func1",
"function": {"name": "func", "arguments": '{"a":1}'},
"type": "function",
},
{
"id": "call_other",
"function": {"name": "other", "arguments": '{"b":2}'},
"type": "function",
},
{
"id": "call_func2",
"function": {"name": "func", "arguments": '{"a":3}'},
"type": "function",
},
]
if use_tool_calls:
return AIMessage(
content="",
tool_calls=[
{"id": "call_func1", "name": "func", "args": {"a": 1}},
{"id": "call_other", "name": "other", "args": {"b": 2}},
{"id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
return AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls_data},
)
result = [ChatGeneration(message=create_message())]
# Test with first_tool_only=True - should return first matching
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(result) # type: ignore[arg-type]
assert output is not None
assert output["type"] == "func"
assert output["args"] == {"a": 1} # First matching tool call
# Test with first_tool_only=False - should return all matching
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
assert len(output_all) == 2
assert output_all[0]["args"] == {"a": 1}
assert output_all[1]["args"] == {"a": 3}
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
def create_message() -> AIMessage:
if use_tool_calls:
return AIMessage(content="", tool_calls=[])
return AIMessage(content="", additional_kwargs={"tool_calls": []})
result = [ChatGeneration(message=create_message())]
# Test with first_tool_only=True
parser = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output = parser.parse_result(result) # type: ignore[arg-type]
# Should return None for empty results
assert output is None
# Test with first_tool_only=False
parser_all = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
# Should return empty list for empty results
assert output_all == []
@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
*, use_tool_calls: bool
) -> None:
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
def create_message() -> AIMessage:
tool_calls_data = [
{
"id": "call_other",
"function": {"name": "other", "arguments": '{"b":2}'},
"type": "function",
},
{
"id": "call_func1",
"function": {"name": "func", "arguments": '{"a":1}'},
"type": "function",
},
{
"id": "call_func2",
"function": {"name": "func", "arguments": '{"a":3}'},
"type": "function",
},
]
if use_tool_calls:
return AIMessage(
content="",
tool_calls=[
{"id": "call_other", "name": "other", "args": {"b": 2}},
{"id": "call_func1", "name": "func", "args": {"a": 1}},
{"id": "call_func2", "name": "func", "args": {"a": 3}},
],
)
return AIMessage(
content="",
additional_kwargs={"tool_calls": tool_calls_data},
)
result: list[ChatGeneration] = [ChatGeneration(message=create_message())]
# Test: first_tool_only=True, return_id=True
parser1 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=True
)
output1 = parser1.parse_result(result) # type: ignore[arg-type]
assert output1["type"] == "func"
assert output1["args"] == {"a": 1}
assert "id" in output1
# Test: first_tool_only=True, return_id=False
parser2 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=True, return_id=False
)
output2 = parser2.parse_result(result) # type: ignore[arg-type]
assert output2 == {"a": 1}
# Test: first_tool_only=False, return_id=True
parser3 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=True
)
output3 = parser3.parse_result(result) # type: ignore[arg-type]
assert len(output3) == 2
assert all("id" in item for item in output3)
assert output3[0]["args"] == {"a": 1}
assert output3[1]["args"] == {"a": 3}
# Test: first_tool_only=False, return_id=False
parser4 = JsonOutputKeyToolsParser(
key_name="func", first_tool_only=False, return_id=False
)
output4 = parser4.parse_result(result) # type: ignore[arg-type]
assert output4 == [{"a": 1}, {"a": 3}]
class Person(BaseModel): class Person(BaseModel):
age: int age: int
hair_color: str hair_color: str