mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
fix(core): fix `parse_result` in case of `self.first_tool_only` with multiple keys matching for `JsonOutputKeyToolsParser` (#32106)

* **Description:** Updated `parse_result` logic to handle cases where `self.first_tool_only` is `True` and multiple matching keys share the same function name. Instead of returning the first match prematurely, the method now prioritizes filtering results by the specified key to ensure correct selection.
* **Issue:** #32100

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
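For context, a minimal reproduction of the reported scenario, adapted from the tests added further down in this commit (`return_id` is left at its default of `False`):

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.outputs import ChatGeneration

# The tool call matching `key_name` is *not* in the first position.
message = AIMessage(
    content="",
    tool_calls=[
        {"id": "call_other", "name": "other", "args": {"b": 2}},
        {"id": "call_func", "name": "func", "args": {"a": 1}},
    ],
)
parser = JsonOutputKeyToolsParser(key_name="func", first_tool_only=True)

# Previously this returned None because only the first tool call was inspected;
# with the fix it filters by key_name first and returns {"a": 1}.
print(parser.parse_result([ChatGeneration(message=message)]))
```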
This commit is contained in:
parent ddaba21e83
commit 095f4a7c28
@ -234,12 +234,39 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
        Returns:
            The parsed tool calls.
        """
        parsed_result = super().parse_result(result, partial=partial)

        generation = result[0]
        if not isinstance(generation, ChatGeneration):
            msg = "This output parser can only be used with a chat generation."
            raise OutputParserException(msg)
        message = generation.message
        if isinstance(message, AIMessage) and message.tool_calls:
            parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
            for tool_call in parsed_tool_calls:
                if not self.return_id:
                    _ = tool_call.pop("id")
        else:
            try:
                raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
            except KeyError:
                if self.first_tool_only:
                    return None
                return []
            parsed_tool_calls = parse_tool_calls(
                raw_tool_calls,
                partial=partial,
                strict=self.strict,
                return_id=self.return_id,
            )
        # For backwards compatibility
        for tc in parsed_tool_calls:
            tc["type"] = tc.pop("name")
        if self.first_tool_only:
            parsed_result = list(
                filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
            )
            single_result = (
                parsed_result
                if parsed_result and parsed_result["type"] == self.key_name
                parsed_result[0]
                if parsed_result and parsed_result[0]["type"] == self.key_name
                else None
            )
            if self.return_id:
@ -247,10 +274,13 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
            if single_result:
                return single_result["args"]
            return None
        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
        if not self.return_id:
            parsed_result = [res["args"] for res in parsed_result]
        return parsed_result
        return (
            [res for res in parsed_tool_calls if res["type"] == self.key_name]
            if self.return_id
            else [
                res["args"] for res in parsed_tool_calls if res["type"] == self.key_name
            ]
        )


# Common cause of ValidationError is truncated output due to max_tokens.
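The crux of the fix is the order of operations when `first_tool_only` is set: filter the parsed tool calls by `key_name` first, then take the first element. A standalone sketch of that rule (the helper name `select_first_matching` is illustrative only, not part of the library):

```python
from typing import Any, Optional


def select_first_matching(
    parsed_tool_calls: list[dict[str, Any]],
    key_name: str,
    *,
    return_id: bool = True,
) -> Optional[Any]:
    """Filter by name first, then take the first match (the fixed behaviour)."""
    matching = [tc for tc in parsed_tool_calls if tc["type"] == key_name]
    if not matching:
        return None
    return matching[0] if return_id else matching[0]["args"]


# The pre-fix logic inspected only parsed_tool_calls[0] and returned None
# whenever that first call had a different name.
calls = [
    {"id": "call_other", "type": "other", "args": {"b": 2}},
    {"id": "call_func", "type": "func", "args": {"a": 1}},
]
assert select_first_matching(calls, "func", return_id=False) == {"a": 1}
```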
@ -1,24 +1,23 @@
"""Output classes.

**Output** classes are used to represent the output of a language model call
and the output of a chat.
Used to represent the output of a language model call and the output of a chat.

The top container for information is the `LLMResult` object. `LLMResult` is used by
both chat models and LLMs. This object contains the output of the language
model and any additional information that the model provider wants to return.
The top container for information is the `LLMResult` object. `LLMResult` is used by both
chat models and LLMs. This object contains the output of the language model and any
additional information that the model provider wants to return.

When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):

- Chat models will return `AIMessage` objects.
- LLMs will return regular text strings.

In addition, users can access the raw output of either LLMs or chat models via
callbacks. The on_chat_model_end and on_llm_end callbacks will return an
callbacks. The ``on_chat_model_end`` and ``on_llm_end`` callbacks will return an
LLMResult object containing the generated outputs and any additional information
returned by the model provider.

In general, if information is already available
in the AIMessage object, it is recommended to access it from there rather than
from the `LLMResult` object.
In general, if information is already available in the AIMessage object, it is
recommended to access it from there rather than from the `LLMResult` object.
"""

from typing import TYPE_CHECKING
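As a minimal sketch of the callback path described above (the handler class and attribute names are illustrative; the model invocation is only indicated in a comment):

```python
from typing import Any

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class CaptureLLMResult(BaseCallbackHandler):
    """Collects the LLMResult emitted when an LLM or chat model run ends."""

    def __init__(self) -> None:
        self.results: list[LLMResult] = []

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # Chat model runs also finish here with an LLMResult.
        self.results.append(response)


handler = CaptureLLMResult()
# Hypothetical usage with any model (not executed here):
#     model.invoke("hello", config={"callbacks": [handler]})
#     first_generation = handler.results[0].generations[0][0]
```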
@ -27,7 +27,11 @@ class ChatGeneration(Generation):
    """

    text: str = ""
    """*SHOULD NOT BE SET DIRECTLY* The text contents of the output message."""
    """The text contents of the output message.

    .. warning::
        SHOULD NOT BE SET DIRECTLY!
    """
    message: BaseMessage
    """The message output by the chat model."""
    # Override type to be ChatGeneration, ignore mypy error as this is intentional
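A small sketch of why `text` should not be set directly: it is derived from the message content when the `ChatGeneration` is constructed (assuming a plain string content):

```python
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration

gen = ChatGeneration(message=AIMessage(content="hello"))
# `text` was filled in from the message; it was not passed explicitly.
assert gen.text == "hello"
```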
@ -11,7 +11,8 @@ from langchain_core.utils._merge import merge_dicts
class Generation(Serializable):
    """A single text generation output.

    Generation represents the response from an "old-fashioned" LLM that
    Generation represents the response from an
    `"old-fashioned" LLM <https://python.langchain.com/docs/concepts/text_llms/>`__ that
    generates regular text (not chat messages).

    This model is used internally by chat model and will eventually
@ -15,9 +15,9 @@ from langchain_core.outputs.run_info import RunInfo
class LLMResult(BaseModel):
    """A container for results of an LLM call.

    Both chat models and LLMs generate an LLMResult object. This object contains
    the generated outputs and any additional information that the model provider
    wants to return.
    Both chat models and LLMs generate an LLMResult object. This object contains the
    generated outputs and any additional information that the model provider wants to
    return.
    """

    generations: list[
@ -25,17 +25,16 @@ class LLMResult(BaseModel):
    ]
    """Generated outputs.

    The first dimension of the list represents completions for different input
    prompts.
    The first dimension of the list represents completions for different input prompts.

    The second dimension of the list represents different candidate generations
    for a given prompt.
    The second dimension of the list represents different candidate generations for a
    given prompt.

    When returned from an LLM the type is list[list[Generation]].
    When returned from a chat model the type is list[list[ChatGeneration]].
    - When returned from **an LLM**, the type is ``list[list[Generation]]``.
    - When returned from a **chat model**, the type is ``list[list[ChatGeneration]]``.

    ChatGeneration is a subclass of Generation that has a field for a structured
    chat message.
    ChatGeneration is a subclass of Generation that has a field for a structured chat
    message.
    """
    llm_output: Optional[dict] = None
    """For arbitrary LLM provider specific output.
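To make the two dimensions concrete, a minimal illustration (the texts are made up):

```python
from langchain_core.outputs import Generation, LLMResult

result = LLMResult(
    generations=[
        [Generation(text="Paris")],   # completions for prompt 0
        [Generation(text="Berlin")],  # completions for prompt 1
    ]
)
assert result.generations[0][0].text == "Paris"
assert result.generations[1][0].text == "Berlin"
```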
@ -43,9 +42,8 @@ class LLMResult(BaseModel):
    This dictionary is a free-form dictionary that can contain any information that the
    provider wants to return. It is not standardized and is provider-specific.

    Users should generally avoid relying on this field and instead rely on
    accessing relevant information from standardized fields present in
    AIMessage.
    Users should generally avoid relying on this field and instead rely on accessing
    relevant information from standardized fields present in AIMessage.
    """
    run: Optional[list[RunInfo]] = None
    """List of metadata info for model call for each input."""
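A short sketch of the recommended alternative: read the standardized fields on `AIMessage` rather than `llm_output` (the metadata values below are illustrative):

```python
from langchain_core.messages import AIMessage

msg = AIMessage(
    content="Hello!",
    response_metadata={"model_name": "example-model", "finish_reason": "stop"},
    usage_metadata={"input_tokens": 3, "output_tokens": 2, "total_tokens": 5},
)
print(msg.response_metadata["finish_reason"])
print(msg.usage_metadata["total_tokens"])
```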
@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only(
    assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_first_only(
    *, use_tool_calls: bool
) -> None:
    # Test case from the original bug report
    def create_message() -> AIMessage:
        tool_calls_data = [
            {
                "id": "call_other",
                "function": {"name": "other", "arguments": '{"b":2}'},
                "type": "function",
            },
            {
                "id": "call_func",
                "function": {"name": "func", "arguments": '{"a":1}'},
                "type": "function",
            },
        ]

        if use_tool_calls:
            return AIMessage(
                content="",
                tool_calls=[
                    {"id": "call_other", "name": "other", "args": {"b": 2}},
                    {"id": "call_func", "name": "func", "args": {"a": 1}},
                ],
            )
        return AIMessage(
            content="",
            additional_kwargs={"tool_calls": tool_calls_data},
        )

    result = [ChatGeneration(message=create_message())]

    # Test with return_id=True
    parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    output = parser.parse_result(result)  # type: ignore[arg-type]

    # Should return the func tool call, not None
    assert output is not None
    assert output["type"] == "func"
    assert output["args"] == {"a": 1}
    assert "id" in output

    # Test with return_id=False
    parser_no_id = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=False
    )
    output_no_id = parser_no_id.parse_result(result)  # type: ignore[arg-type]

    # Should return just the args
    assert output_no_id == {"a": 1}


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_tools_no_match(
    *, use_tool_calls: bool
) -> None:
    def create_message() -> AIMessage:
        tool_calls_data = [
            {
                "id": "call_other",
                "function": {"name": "other", "arguments": '{"b":2}'},
                "type": "function",
            },
            {
                "id": "call_another",
                "function": {"name": "another", "arguments": '{"c":3}'},
                "type": "function",
            },
        ]

        if use_tool_calls:
            return AIMessage(
                content="",
                tool_calls=[
                    {"id": "call_other", "name": "other", "args": {"b": 2}},
                    {"id": "call_another", "name": "another", "args": {"c": 3}},
                ],
            )
        return AIMessage(
            content="",
            additional_kwargs={"tool_calls": tool_calls_data},
        )

    result = [ChatGeneration(message=create_message())]

    # Test with return_id=True, first_tool_only=True
    parser = JsonOutputKeyToolsParser(
        key_name="nonexistent", first_tool_only=True, return_id=True
    )
    output = parser.parse_result(result)  # type: ignore[arg-type]

    # Should return None when no matches
    assert output is None

    # Test with return_id=False, first_tool_only=True
    parser_no_id = JsonOutputKeyToolsParser(
        key_name="nonexistent", first_tool_only=True, return_id=False
    )
    output_no_id = parser_no_id.parse_result(result)  # type: ignore[arg-type]

    # Should return None when no matches
    assert output_no_id is None


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_multiple_matching_tools(
    *, use_tool_calls: bool
) -> None:
    def create_message() -> AIMessage:
        tool_calls_data = [
            {
                "id": "call_func1",
                "function": {"name": "func", "arguments": '{"a":1}'},
                "type": "function",
            },
            {
                "id": "call_other",
                "function": {"name": "other", "arguments": '{"b":2}'},
                "type": "function",
            },
            {
                "id": "call_func2",
                "function": {"name": "func", "arguments": '{"a":3}'},
                "type": "function",
            },
        ]

        if use_tool_calls:
            return AIMessage(
                content="",
                tool_calls=[
                    {"id": "call_func1", "name": "func", "args": {"a": 1}},
                    {"id": "call_other", "name": "other", "args": {"b": 2}},
                    {"id": "call_func2", "name": "func", "args": {"a": 3}},
                ],
            )
        return AIMessage(
            content="",
            additional_kwargs={"tool_calls": tool_calls_data},
        )

    result = [ChatGeneration(message=create_message())]

    # Test with first_tool_only=True - should return first matching
    parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    output = parser.parse_result(result)  # type: ignore[arg-type]

    assert output is not None
    assert output["type"] == "func"
    assert output["args"] == {"a": 1}  # First matching tool call

    # Test with first_tool_only=False - should return all matching
    parser_all = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    output_all = parser_all.parse_result(result)  # type: ignore[arg-type]

    assert len(output_all) == 2
    assert output_all[0]["args"] == {"a": 1}
    assert output_all[1]["args"] == {"a": 3}


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None:
    def create_message() -> AIMessage:
        if use_tool_calls:
            return AIMessage(content="", tool_calls=[])
        return AIMessage(content="", additional_kwargs={"tool_calls": []})

    result = [ChatGeneration(message=create_message())]

    # Test with first_tool_only=True
    parser = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    output = parser.parse_result(result)  # type: ignore[arg-type]

    # Should return None for empty results
    assert output is None

    # Test with first_tool_only=False
    parser_all = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    output_all = parser_all.parse_result(result)  # type: ignore[arg-type]

    # Should return empty list for empty results
    assert output_all == []


@pytest.mark.parametrize("use_tool_calls", [False, True])
def test_json_output_key_tools_parser_parameter_combinations(
    *, use_tool_calls: bool
) -> None:
    """Test all parameter combinations of JsonOutputKeyToolsParser."""

    def create_message() -> AIMessage:
        tool_calls_data = [
            {
                "id": "call_other",
                "function": {"name": "other", "arguments": '{"b":2}'},
                "type": "function",
            },
            {
                "id": "call_func1",
                "function": {"name": "func", "arguments": '{"a":1}'},
                "type": "function",
            },
            {
                "id": "call_func2",
                "function": {"name": "func", "arguments": '{"a":3}'},
                "type": "function",
            },
        ]

        if use_tool_calls:
            return AIMessage(
                content="",
                tool_calls=[
                    {"id": "call_other", "name": "other", "args": {"b": 2}},
                    {"id": "call_func1", "name": "func", "args": {"a": 1}},
                    {"id": "call_func2", "name": "func", "args": {"a": 3}},
                ],
            )
        return AIMessage(
            content="",
            additional_kwargs={"tool_calls": tool_calls_data},
        )

    result: list[ChatGeneration] = [ChatGeneration(message=create_message())]

    # Test: first_tool_only=True, return_id=True
    parser1 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=True
    )
    output1 = parser1.parse_result(result)  # type: ignore[arg-type]
    assert output1["type"] == "func"
    assert output1["args"] == {"a": 1}
    assert "id" in output1

    # Test: first_tool_only=True, return_id=False
    parser2 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=True, return_id=False
    )
    output2 = parser2.parse_result(result)  # type: ignore[arg-type]
    assert output2 == {"a": 1}

    # Test: first_tool_only=False, return_id=True
    parser3 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=True
    )
    output3 = parser3.parse_result(result)  # type: ignore[arg-type]
    assert len(output3) == 2
    assert all("id" in item for item in output3)
    assert output3[0]["args"] == {"a": 1}
    assert output3[1]["args"] == {"a": 3}

    # Test: first_tool_only=False, return_id=False
    parser4 = JsonOutputKeyToolsParser(
        key_name="func", first_tool_only=False, return_id=False
    )
    output4 = parser4.parse_result(result)  # type: ignore[arg-type]
    assert output4 == [{"a": 1}, {"a": 3}]


class Person(BaseModel):
    age: int
    hair_color: str