mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-03 10:12:33 +00:00
fix(core): fix parse_result
in case of self.first_tool_only with multiple keys matching for JsonOutputKeyToolsParser (#32106)
* **Description:** Updated `parse_result` logic to handle cases where `self.first_tool_only` is `True` and multiple matching keys share the same function name. Instead of returning the first match prematurely, the method now prioritizes filtering results by the specified key to ensure correct selection.
* **Issue:** #32100

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
ddaba21e83
commit
095f4a7c28
@ -234,12 +234,39 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
Returns:
|
Returns:
|
||||||
The parsed tool calls.
|
The parsed tool calls.
|
||||||
"""
|
"""
|
||||||
parsed_result = super().parse_result(result, partial=partial)
|
generation = result[0]
|
||||||
|
if not isinstance(generation, ChatGeneration):
|
||||||
|
msg = "This output parser can only be used with a chat generation."
|
||||||
|
raise OutputParserException(msg)
|
||||||
|
message = generation.message
|
||||||
|
if isinstance(message, AIMessage) and message.tool_calls:
|
||||||
|
parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
|
||||||
|
for tool_call in parsed_tool_calls:
|
||||||
|
if not self.return_id:
|
||||||
|
_ = tool_call.pop("id")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
|
||||||
|
except KeyError:
|
||||||
if self.first_tool_only:
|
if self.first_tool_only:
|
||||||
|
return None
|
||||||
|
return []
|
||||||
|
parsed_tool_calls = parse_tool_calls(
|
||||||
|
raw_tool_calls,
|
||||||
|
partial=partial,
|
||||||
|
strict=self.strict,
|
||||||
|
return_id=self.return_id,
|
||||||
|
)
|
||||||
|
# For backwards compatibility
|
||||||
|
for tc in parsed_tool_calls:
|
||||||
|
tc["type"] = tc.pop("name")
|
||||||
|
if self.first_tool_only:
|
||||||
|
parsed_result = list(
|
||||||
|
filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
|
||||||
|
)
|
||||||
single_result = (
|
single_result = (
|
||||||
parsed_result
|
parsed_result[0]
|
||||||
if parsed_result and parsed_result["type"] == self.key_name
|
if parsed_result and parsed_result[0]["type"] == self.key_name
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
if self.return_id:
|
if self.return_id:
|
||||||
@ -247,10 +274,13 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
|
|||||||
if single_result:
|
if single_result:
|
||||||
return single_result["args"]
|
return single_result["args"]
|
||||||
return None
|
return None
|
||||||
parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
|
return (
|
||||||
if not self.return_id:
|
[res for res in parsed_tool_calls if res["type"] == self.key_name]
|
||||||
parsed_result = [res["args"] for res in parsed_result]
|
if self.return_id
|
||||||
return parsed_result
|
else [
|
||||||
|
res["args"] for res in parsed_tool_calls if res["type"] == self.key_name
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Common cause of ValidationError is truncated output due to max_tokens.
|
# Common cause of ValidationError is truncated output due to max_tokens.
|
||||||
|
@ -1,24 +1,23 @@
|
|||||||
"""Output classes.
|
"""Output classes.
|
||||||
|
|
||||||
**Output** classes are used to represent the output of a language model call
|
Used to represent the output of a language model call and the output of a chat.
|
||||||
and the output of a chat.
|
|
||||||
|
|
||||||
The top container for information is the `LLMResult` object. `LLMResult` is used by
|
The top container for information is the `LLMResult` object. `LLMResult` is used by both
|
||||||
both chat models and LLMs. This object contains the output of the language
|
chat models and LLMs. This object contains the output of the language model and any
|
||||||
model and any additional information that the model provider wants to return.
|
additional information that the model provider wants to return.
|
||||||
|
|
||||||
When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):
|
When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):
|
||||||
|
|
||||||
- Chat models will return `AIMessage` objects.
|
- Chat models will return `AIMessage` objects.
|
||||||
- LLMs will return regular text strings.
|
- LLMs will return regular text strings.
|
||||||
|
|
||||||
In addition, users can access the raw output of either LLMs or chat models via
|
In addition, users can access the raw output of either LLMs or chat models via
|
||||||
callbacks. The on_chat_model_end and on_llm_end callbacks will return an
|
callbacks. The ``on_chat_model_end`` and ``on_llm_end`` callbacks will return an
|
||||||
LLMResult object containing the generated outputs and any additional information
|
LLMResult object containing the generated outputs and any additional information
|
||||||
returned by the model provider.
|
returned by the model provider.
|
||||||
|
|
||||||
In general, if information is already available
|
In general, if information is already available in the AIMessage object, it is
|
||||||
in the AIMessage object, it is recommended to access it from there rather than
|
recommended to access it from there rather than from the `LLMResult` object.
|
||||||
from the `LLMResult` object.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
@ -27,7 +27,11 @@ class ChatGeneration(Generation):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
text: str = ""
|
text: str = ""
|
||||||
"""*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
|
"""The text contents of the output message.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
SHOULD NOT BE SET DIRECTLY!
|
||||||
|
"""
|
||||||
message: BaseMessage
|
message: BaseMessage
|
||||||
"""The message output by the chat model."""
|
"""The message output by the chat model."""
|
||||||
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
# Override type to be ChatGeneration, ignore mypy error as this is intentional
|
||||||
|
@ -11,7 +11,8 @@ from langchain_core.utils._merge import merge_dicts
|
|||||||
class Generation(Serializable):
|
class Generation(Serializable):
|
||||||
"""A single text generation output.
|
"""A single text generation output.
|
||||||
|
|
||||||
Generation represents the response from an "old-fashioned" LLM that
|
Generation represents the response from an
|
||||||
|
`"old-fashioned" LLM <https://python.langchain.com/docs/concepts/text_llms/>`__ that
|
||||||
generates regular text (not chat messages).
|
generates regular text (not chat messages).
|
||||||
|
|
||||||
This model is used internally by chat model and will eventually
|
This model is used internally by chat model and will eventually
|
||||||
|
@ -15,9 +15,9 @@ from langchain_core.outputs.run_info import RunInfo
|
|||||||
class LLMResult(BaseModel):
|
class LLMResult(BaseModel):
|
||||||
"""A container for results of an LLM call.
|
"""A container for results of an LLM call.
|
||||||
|
|
||||||
Both chat models and LLMs generate an LLMResult object. This object contains
|
Both chat models and LLMs generate an LLMResult object. This object contains the
|
||||||
the generated outputs and any additional information that the model provider
|
generated outputs and any additional information that the model provider wants to
|
||||||
wants to return.
|
return.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
generations: list[
|
generations: list[
|
||||||
@ -25,17 +25,16 @@ class LLMResult(BaseModel):
|
|||||||
]
|
]
|
||||||
"""Generated outputs.
|
"""Generated outputs.
|
||||||
|
|
||||||
The first dimension of the list represents completions for different input
|
The first dimension of the list represents completions for different input prompts.
|
||||||
prompts.
|
|
||||||
|
|
||||||
The second dimension of the list represents different candidate generations
|
The second dimension of the list represents different candidate generations for a
|
||||||
for a given prompt.
|
given prompt.
|
||||||
|
|
||||||
When returned from an LLM the type is list[list[Generation]].
|
- When returned from **an LLM**, the type is ``list[list[Generation]]``.
|
||||||
When returned from a chat model the type is list[list[ChatGeneration]].
|
- When returned from a **chat model**, the type is ``list[list[ChatGeneration]]``.
|
||||||
|
|
||||||
ChatGeneration is a subclass of Generation that has a field for a structured
|
ChatGeneration is a subclass of Generation that has a field for a structured chat
|
||||||
chat message.
|
message.
|
||||||
"""
|
"""
|
||||||
llm_output: Optional[dict] = None
|
llm_output: Optional[dict] = None
|
||||||
"""For arbitrary LLM provider specific output.
|
"""For arbitrary LLM provider specific output.
|
||||||
@ -43,9 +42,8 @@ class LLMResult(BaseModel):
|
|||||||
This dictionary is a free-form dictionary that can contain any information that the
|
This dictionary is a free-form dictionary that can contain any information that the
|
||||||
provider wants to return. It is not standardized and is provider-specific.
|
provider wants to return. It is not standardized and is provider-specific.
|
||||||
|
|
||||||
Users should generally avoid relying on this field and instead rely on
|
Users should generally avoid relying on this field and instead rely on accessing
|
||||||
accessing relevant information from standardized fields present in
|
relevant information from standardized fields present in AIMessage.
|
||||||
AIMessage.
|
|
||||||
"""
|
"""
|
||||||
run: Optional[list[RunInfo]] = None
|
run: Optional[list[RunInfo]] = None
|
||||||
"""List of metadata info for model call for each input."""
|
"""List of metadata info for model call for each input."""
|
||||||
|
@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only(
|
|||||||
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
|
def test_json_output_key_tools_parser_multiple_tools_first_only(
|
||||||
|
*, use_tool_calls: bool
|
||||||
|
) -> None:
|
||||||
|
# Test case from the original bug report
|
||||||
|
def create_message() -> AIMessage:
|
||||||
|
tool_calls_data = [
|
||||||
|
{
|
||||||
|
"id": "call_other",
|
||||||
|
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_func",
|
||||||
|
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_tool_calls:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||||
|
{"id": "call_func", "name": "func", "args": {"a": 1}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"tool_calls": tool_calls_data},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = [ChatGeneration(message=create_message())]
|
||||||
|
|
||||||
|
# Test with return_id=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return the func tool call, not None
|
||||||
|
assert output is not None
|
||||||
|
assert output["type"] == "func"
|
||||||
|
assert output["args"] == {"a": 1}
|
||||||
|
assert "id" in output
|
||||||
|
|
||||||
|
# Test with return_id=False
|
||||||
|
parser_no_id = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return just the args
|
||||||
|
assert output_no_id == {"a": 1}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
|
def test_json_output_key_tools_parser_multiple_tools_no_match(
|
||||||
|
*, use_tool_calls: bool
|
||||||
|
) -> None:
|
||||||
|
def create_message() -> AIMessage:
|
||||||
|
tool_calls_data = [
|
||||||
|
{
|
||||||
|
"id": "call_other",
|
||||||
|
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_another",
|
||||||
|
"function": {"name": "another", "arguments": '{"c":3}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_tool_calls:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||||
|
{"id": "call_another", "name": "another", "args": {"c": 3}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"tool_calls": tool_calls_data},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = [ChatGeneration(message=create_message())]
|
||||||
|
|
||||||
|
# Test with return_id=True, first_tool_only=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="nonexistent", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return None when no matches
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
# Test with return_id=False, first_tool_only=True
|
||||||
|
parser_no_id = JsonOutputKeyToolsParser(
|
||||||
|
key_name="nonexistent", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return None when no matches
|
||||||
|
assert output_no_id is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
|
def test_json_output_key_tools_parser_multiple_matching_tools(
|
||||||
|
*, use_tool_calls: bool
|
||||||
|
) -> None:
|
||||||
|
def create_message() -> AIMessage:
|
||||||
|
tool_calls_data = [
|
||||||
|
{
|
||||||
|
"id": "call_func1",
|
||||||
|
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_other",
|
||||||
|
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_func2",
|
||||||
|
"function": {"name": "func", "arguments": '{"a":3}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_tool_calls:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{"id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||||
|
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||||
|
{"id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"tool_calls": tool_calls_data},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = [ChatGeneration(message=create_message())]
|
||||||
|
|
||||||
|
# Test with first_tool_only=True - should return first matching
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert output is not None
|
||||||
|
assert output["type"] == "func"
|
||||||
|
assert output["args"] == {"a": 1} # First matching tool call
|
||||||
|
|
||||||
|
# Test with first_tool_only=False - should return all matching
|
||||||
|
parser_all = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert len(output_all) == 2
|
||||||
|
assert output_all[0]["args"] == {"a": 1}
|
||||||
|
assert output_all[1]["args"] == {"a": 3}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
|
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
|
||||||
|
def create_message() -> AIMessage:
|
||||||
|
if use_tool_calls:
|
||||||
|
return AIMessage(content="", tool_calls=[])
|
||||||
|
return AIMessage(content="", additional_kwargs={"tool_calls": []})
|
||||||
|
|
||||||
|
result = [ChatGeneration(message=create_message())]
|
||||||
|
|
||||||
|
# Test with first_tool_only=True
|
||||||
|
parser = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output = parser.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return None for empty results
|
||||||
|
assert output is None
|
||||||
|
|
||||||
|
# Test with first_tool_only=False
|
||||||
|
parser_all = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output_all = parser_all.parse_result(result) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# Should return empty list for empty results
|
||||||
|
assert output_all == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("use_tool_calls", [False, True])
|
||||||
|
def test_json_output_key_tools_parser_parameter_combinations(
|
||||||
|
*, use_tool_calls: bool
|
||||||
|
) -> None:
|
||||||
|
"""Test all parameter combinations of JsonOutputKeyToolsParser."""
|
||||||
|
|
||||||
|
def create_message() -> AIMessage:
|
||||||
|
tool_calls_data = [
|
||||||
|
{
|
||||||
|
"id": "call_other",
|
||||||
|
"function": {"name": "other", "arguments": '{"b":2}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_func1",
|
||||||
|
"function": {"name": "func", "arguments": '{"a":1}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "call_func2",
|
||||||
|
"function": {"name": "func", "arguments": '{"a":3}'},
|
||||||
|
"type": "function",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
if use_tool_calls:
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
{"id": "call_other", "name": "other", "args": {"b": 2}},
|
||||||
|
{"id": "call_func1", "name": "func", "args": {"a": 1}},
|
||||||
|
{"id": "call_func2", "name": "func", "args": {"a": 3}},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
return AIMessage(
|
||||||
|
content="",
|
||||||
|
additional_kwargs={"tool_calls": tool_calls_data},
|
||||||
|
)
|
||||||
|
|
||||||
|
result: list[ChatGeneration] = [ChatGeneration(message=create_message())]
|
||||||
|
|
||||||
|
# Test: first_tool_only=True, return_id=True
|
||||||
|
parser1 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=True
|
||||||
|
)
|
||||||
|
output1 = parser1.parse_result(result) # type: ignore[arg-type]
|
||||||
|
assert output1["type"] == "func"
|
||||||
|
assert output1["args"] == {"a": 1}
|
||||||
|
assert "id" in output1
|
||||||
|
|
||||||
|
# Test: first_tool_only=True, return_id=False
|
||||||
|
parser2 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=True, return_id=False
|
||||||
|
)
|
||||||
|
output2 = parser2.parse_result(result) # type: ignore[arg-type]
|
||||||
|
assert output2 == {"a": 1}
|
||||||
|
|
||||||
|
# Test: first_tool_only=False, return_id=True
|
||||||
|
parser3 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=True
|
||||||
|
)
|
||||||
|
output3 = parser3.parse_result(result) # type: ignore[arg-type]
|
||||||
|
assert len(output3) == 2
|
||||||
|
assert all("id" in item for item in output3)
|
||||||
|
assert output3[0]["args"] == {"a": 1}
|
||||||
|
assert output3[1]["args"] == {"a": 3}
|
||||||
|
|
||||||
|
# Test: first_tool_only=False, return_id=False
|
||||||
|
parser4 = JsonOutputKeyToolsParser(
|
||||||
|
key_name="func", first_tool_only=False, return_id=False
|
||||||
|
)
|
||||||
|
output4 = parser4.parse_result(result) # type: ignore[arg-type]
|
||||||
|
assert output4 == [{"a": 1}, {"a": 3}]
|
||||||
|
|
||||||
|
|
||||||
class Person(BaseModel):
|
class Person(BaseModel):
|
||||||
age: int
|
age: int
|
||||||
hair_color: str
|
hair_color: str
|
||||||
|
Loading…
Reference in New Issue
Block a user