diff --git a/libs/core/langchain_core/output_parsers/openai_tools.py b/libs/core/langchain_core/output_parsers/openai_tools.py
index b4f845bdcad..63495bc2d84 100644
--- a/libs/core/langchain_core/output_parsers/openai_tools.py
+++ b/libs/core/langchain_core/output_parsers/openai_tools.py
@@ -234,12 +234,39 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
         Returns:
             The parsed tool calls.
         """
-        parsed_result = super().parse_result(result, partial=partial)
-
+        generation = result[0]
+        if not isinstance(generation, ChatGeneration):
+            msg = "This output parser can only be used with a chat generation."
+            raise OutputParserException(msg)
+        message = generation.message
+        if isinstance(message, AIMessage) and message.tool_calls:
+            parsed_tool_calls = [dict(tc) for tc in message.tool_calls]
+            for tool_call in parsed_tool_calls:
+                if not self.return_id:
+                    _ = tool_call.pop("id")
+        else:
+            try:
+                raw_tool_calls = copy.deepcopy(message.additional_kwargs["tool_calls"])
+            except KeyError:
+                if self.first_tool_only:
+                    return None
+                return []
+            parsed_tool_calls = parse_tool_calls(
+                raw_tool_calls,
+                partial=partial,
+                strict=self.strict,
+                return_id=self.return_id,
+            )
+        # For backwards compatibility
+        for tc in parsed_tool_calls:
+            tc["type"] = tc.pop("name")
         if self.first_tool_only:
+            parsed_result = list(
+                filter(lambda x: x["type"] == self.key_name, parsed_tool_calls)
+            )
             single_result = (
-                parsed_result
-                if parsed_result and parsed_result["type"] == self.key_name
+                parsed_result[0]
+                if parsed_result and parsed_result[0]["type"] == self.key_name
                 else None
             )
             if self.return_id:
@@ -247,10 +274,13 @@ class JsonOutputKeyToolsParser(JsonOutputToolsParser):
             if single_result:
                 return single_result["args"]
             return None
-        parsed_result = [res for res in parsed_result if res["type"] == self.key_name]
-        if not self.return_id:
-            parsed_result = [res["args"] for res in parsed_result]
-        return parsed_result
+        return (
+            [res for res in parsed_tool_calls if res["type"] == self.key_name]
+            if self.return_id
+            else [
+                res["args"] for res in parsed_tool_calls if res["type"] == self.key_name
+            ]
+        )
 
 
 # Common cause of ValidationError is truncated output due to max_tokens.
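A minimal sketch of the behavior this hunk changes (not part of the diff; it uses the same public API as the tests added below). With `first_tool_only=True`, the parser now returns the first call *matching* `key_name` instead of only inspecting the first call overall:

```python
from langchain_core.messages import AIMessage
from langchain_core.output_parsers.openai_tools import JsonOutputKeyToolsParser
from langchain_core.outputs import ChatGeneration

# Two tool calls; the one matching key_name is NOT first.
message = AIMessage(
    content="",
    tool_calls=[
        {"id": "call_other", "name": "other", "args": {"b": 2}},
        {"id": "call_func", "name": "func", "args": {"a": 1}},
    ],
)
parser = JsonOutputKeyToolsParser(key_name="func", first_tool_only=True)
output = parser.parse_result([ChatGeneration(message=message)])
# Before this patch: None, because only the first call was checked against
# key_name. After this patch: the args of the first *matching* call.
assert output == {"a": 1}
```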
diff --git a/libs/core/langchain_core/outputs/__init__.py b/libs/core/langchain_core/outputs/__init__.py
index b9072b9c929..1e64d4a83aa 100644
--- a/libs/core/langchain_core/outputs/__init__.py
+++ b/libs/core/langchain_core/outputs/__init__.py
@@ -1,24 +1,23 @@
 """Output classes.
 
-**Output** classes are used to represent the output of a language model call
-and the output of a chat.
+Used to represent the output of a language model call and the output of a chat.
 
-The top container for information is the `LLMResult` object. `LLMResult` is used by
-both chat models and LLMs. This object contains the output of the language
-model and any additional information that the model provider wants to return.
+The top container for information is the `LLMResult` object. `LLMResult` is used by both
+chat models and LLMs. This object contains the output of the language model and any
+additional information that the model provider wants to return.
 
 When invoking models via the standard runnable methods (e.g. invoke, batch, etc.):
+
 - Chat models will return `AIMessage` objects.
 - LLMs will return regular text strings.
 
 In addition, users can access the raw output of either LLMs or chat models via
-callbacks. The on_chat_model_end and on_llm_end callbacks will return an
+callbacks. The ``on_chat_model_end`` and ``on_llm_end`` callbacks will return an
 LLMResult object containing the generated outputs and any additional information
 returned by the model provider.
 
-In general, if information is already available
-in the AIMessage object, it is recommended to access it from there rather than
-from the `LLMResult` object.
+In general, if information is already available in the AIMessage object, it is
+recommended to access it from there rather than from the `LLMResult` object.
 """
 
 from typing import TYPE_CHECKING
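To make the callback path described in that docstring concrete, here is a minimal sketch (not part of the diff; `CaptureResult` is an illustrative name, and the commented-out `model` stands in for any chat model or LLM):

```python
from typing import Any, Optional

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class CaptureResult(BaseCallbackHandler):
    """Stash the raw LLMResult produced by a model run."""

    def __init__(self) -> None:
        self.result: Optional[LLMResult] = None

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # Fired for both LLMs and chat models once generation finishes.
        self.result = response


# handler = CaptureResult()
# model.invoke("hi", config={"callbacks": [handler]})
# handler.result.generations  # raw generations plus provider metadata
```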
""" llm_output: Optional[dict] = None """For arbitrary LLM provider specific output. @@ -43,9 +42,8 @@ class LLMResult(BaseModel): This dictionary is a free-form dictionary that can contain any information that the provider wants to return. It is not standardized and is provider-specific. - Users should generally avoid relying on this field and instead rely on - accessing relevant information from standardized fields present in - AIMessage. + Users should generally avoid relying on this field and instead rely on accessing + relevant information from standardized fields present in AIMessage. """ run: Optional[list[RunInfo]] = None """List of metadata info for model call for each input.""" diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index 3fbb65c63b0..74862a8386a 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -475,6 +475,277 @@ async def test_partial_json_output_parser_key_async_first_only( assert [p async for p in chain.astream(None)] == EXPECTED_STREAMED_JSON +@pytest.mark.parametrize("use_tool_calls", [False, True]) +def test_json_output_key_tools_parser_multiple_tools_first_only( + *, use_tool_calls: bool +) -> None: + # Test case from the original bug report + def create_message() -> AIMessage: + tool_calls_data = [ + { + "id": "call_other", + "function": {"name": "other", "arguments": '{"b":2}'}, + "type": "function", + }, + { + "id": "call_func", + "function": {"name": "func", "arguments": '{"a":1}'}, + "type": "function", + }, + ] + + if use_tool_calls: + return AIMessage( + content="", + tool_calls=[ + {"id": "call_other", "name": "other", "args": {"b": 2}}, + {"id": "call_func", "name": "func", "args": {"a": 1}}, + ], + ) + return AIMessage( + content="", + additional_kwargs={"tool_calls": tool_calls_data}, + ) + + result = [ChatGeneration(message=create_message())] + + # Test with return_id=True + parser = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=True + ) + output = parser.parse_result(result) # type: ignore[arg-type] + + # Should return the func tool call, not None + assert output is not None + assert output["type"] == "func" + assert output["args"] == {"a": 1} + assert "id" in output + + # Test with return_id=False + parser_no_id = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=False + ) + output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type] + + # Should return just the args + assert output_no_id == {"a": 1} + + +@pytest.mark.parametrize("use_tool_calls", [False, True]) +def test_json_output_key_tools_parser_multiple_tools_no_match( + *, use_tool_calls: bool +) -> None: + def create_message() -> AIMessage: + tool_calls_data = [ + { + "id": "call_other", + "function": {"name": "other", "arguments": '{"b":2}'}, + "type": "function", + }, + { + "id": "call_another", + "function": {"name": "another", "arguments": '{"c":3}'}, + "type": "function", + }, + ] + + if use_tool_calls: + return AIMessage( + content="", + tool_calls=[ + {"id": "call_other", "name": "other", "args": {"b": 2}}, + {"id": "call_another", "name": "another", "args": {"c": 3}}, + ], + ) + return AIMessage( + content="", + additional_kwargs={"tool_calls": tool_calls_data}, + ) + + result = [ChatGeneration(message=create_message())] + + # Test with return_id=True, first_tool_only=True + parser = JsonOutputKeyToolsParser( + 
key_name="nonexistent", first_tool_only=True, return_id=True + ) + output = parser.parse_result(result) # type: ignore[arg-type] + + # Should return None when no matches + assert output is None + + # Test with return_id=False, first_tool_only=True + parser_no_id = JsonOutputKeyToolsParser( + key_name="nonexistent", first_tool_only=True, return_id=False + ) + output_no_id = parser_no_id.parse_result(result) # type: ignore[arg-type] + + # Should return None when no matches + assert output_no_id is None + + +@pytest.mark.parametrize("use_tool_calls", [False, True]) +def test_json_output_key_tools_parser_multiple_matching_tools( + *, use_tool_calls: bool +) -> None: + def create_message() -> AIMessage: + tool_calls_data = [ + { + "id": "call_func1", + "function": {"name": "func", "arguments": '{"a":1}'}, + "type": "function", + }, + { + "id": "call_other", + "function": {"name": "other", "arguments": '{"b":2}'}, + "type": "function", + }, + { + "id": "call_func2", + "function": {"name": "func", "arguments": '{"a":3}'}, + "type": "function", + }, + ] + + if use_tool_calls: + return AIMessage( + content="", + tool_calls=[ + {"id": "call_func1", "name": "func", "args": {"a": 1}}, + {"id": "call_other", "name": "other", "args": {"b": 2}}, + {"id": "call_func2", "name": "func", "args": {"a": 3}}, + ], + ) + return AIMessage( + content="", + additional_kwargs={"tool_calls": tool_calls_data}, + ) + + result = [ChatGeneration(message=create_message())] + + # Test with first_tool_only=True - should return first matching + parser = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=True + ) + output = parser.parse_result(result) # type: ignore[arg-type] + + assert output is not None + assert output["type"] == "func" + assert output["args"] == {"a": 1} # First matching tool call + + # Test with first_tool_only=False - should return all matching + parser_all = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=False, return_id=True + ) + output_all = parser_all.parse_result(result) # type: ignore[arg-type] + + assert len(output_all) == 2 + assert output_all[0]["args"] == {"a": 1} + assert output_all[1]["args"] == {"a": 3} + + +@pytest.mark.parametrize("use_tool_calls", [False, True]) +def test_json_output_key_tools_parser_empty_results(*, use_tool_calls: bool) -> None: + def create_message() -> AIMessage: + if use_tool_calls: + return AIMessage(content="", tool_calls=[]) + return AIMessage(content="", additional_kwargs={"tool_calls": []}) + + result = [ChatGeneration(message=create_message())] + + # Test with first_tool_only=True + parser = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=True + ) + output = parser.parse_result(result) # type: ignore[arg-type] + + # Should return None for empty results + assert output is None + + # Test with first_tool_only=False + parser_all = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=False, return_id=True + ) + output_all = parser_all.parse_result(result) # type: ignore[arg-type] + + # Should return empty list for empty results + assert output_all == [] + + +@pytest.mark.parametrize("use_tool_calls", [False, True]) +def test_json_output_key_tools_parser_parameter_combinations( + *, use_tool_calls: bool +) -> None: + """Test all parameter combinations of JsonOutputKeyToolsParser.""" + + def create_message() -> AIMessage: + tool_calls_data = [ + { + "id": "call_other", + "function": {"name": "other", "arguments": '{"b":2}'}, + "type": "function", + }, + { + "id": "call_func1", + "function": 
{"name": "func", "arguments": '{"a":1}'}, + "type": "function", + }, + { + "id": "call_func2", + "function": {"name": "func", "arguments": '{"a":3}'}, + "type": "function", + }, + ] + + if use_tool_calls: + return AIMessage( + content="", + tool_calls=[ + {"id": "call_other", "name": "other", "args": {"b": 2}}, + {"id": "call_func1", "name": "func", "args": {"a": 1}}, + {"id": "call_func2", "name": "func", "args": {"a": 3}}, + ], + ) + return AIMessage( + content="", + additional_kwargs={"tool_calls": tool_calls_data}, + ) + + result: list[ChatGeneration] = [ChatGeneration(message=create_message())] + + # Test: first_tool_only=True, return_id=True + parser1 = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=True + ) + output1 = parser1.parse_result(result) # type: ignore[arg-type] + assert output1["type"] == "func" + assert output1["args"] == {"a": 1} + assert "id" in output1 + + # Test: first_tool_only=True, return_id=False + parser2 = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=True, return_id=False + ) + output2 = parser2.parse_result(result) # type: ignore[arg-type] + assert output2 == {"a": 1} + + # Test: first_tool_only=False, return_id=True + parser3 = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=False, return_id=True + ) + output3 = parser3.parse_result(result) # type: ignore[arg-type] + assert len(output3) == 2 + assert all("id" in item for item in output3) + assert output3[0]["args"] == {"a": 1} + assert output3[1]["args"] == {"a": 3} + + # Test: first_tool_only=False, return_id=False + parser4 = JsonOutputKeyToolsParser( + key_name="func", first_tool_only=False, return_id=False + ) + output4 = parser4.parse_result(result) # type: ignore[arg-type] + assert output4 == [{"a": 1}, {"a": 3}] + + class Person(BaseModel): age: int hair_color: str