diff --git a/libs/partners/anthropic/langchain_anthropic/output_parsers.py b/libs/partners/anthropic/langchain_anthropic/output_parsers.py index fa8d3bd4191..c8d6aa3aeec 100644 --- a/libs/partners/anthropic/langchain_anthropic/output_parsers.py +++ b/libs/partners/anthropic/langchain_anthropic/output_parsers.py @@ -1,6 +1,6 @@ -from typing import Any, List, Optional, Type +from typing import Any, List, Optional, Type, Union, cast -from langchain_core.messages import ToolCall +from langchain_core.messages import AIMessage, ToolCall from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.outputs import ChatGeneration, Generation from langchain_core.pydantic_v1 import BaseModel @@ -31,19 +31,18 @@ class ToolsOutputParser(BaseGenerationOutputParser): """ if not result or not isinstance(result[0], ChatGeneration): return None if self.first_tool_only else [] - message = result[0].message - if isinstance(message.content, str): - tool_calls: List = [] - else: - content: List = message.content - _tool_calls = [dict(tc) for tc in extract_tool_calls(content)] + message = cast(AIMessage, result[0].message) + tool_calls: List = [ + dict(tc) for tc in _extract_tool_calls_from_message(message) + ] + if isinstance(message.content, list): # Map tool call id to index id_to_index = { block["id"]: i - for i, block in enumerate(content) - if block["type"] == "tool_use" + for i, block in enumerate(message.content) + if isinstance(block, dict) and block["type"] == "tool_use" } - tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls] + tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in tool_calls] if self.pydantic_schemas: tool_calls = [self._pydantic_parse(tc) for tc in tool_calls] elif self.args_only: @@ -63,13 +62,25 @@ class ToolsOutputParser(BaseGenerationOutputParser): return cls_(**tool_call["args"]) -def extract_tool_calls(content: List[dict]) -> List[ToolCall]: +def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]: """Extract tool calls from a list of content blocks.""" - tool_calls = [] - for block in content: - if block["type"] != "tool_use": - continue - tool_calls.append( - ToolCall(name=block["name"], args=block["input"], id=block["id"]) - ) - return tool_calls + if message.tool_calls: + return message.tool_calls + return extract_tool_calls(message.content) + + +def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]: + """Extract tool calls from a list of content blocks.""" + if isinstance(content, list): + tool_calls = [] + for block in content: + if isinstance(block, str): + continue + if block["type"] != "tool_use": + continue + tool_calls.append( + ToolCall(name=block["name"], args=block["input"], id=block["id"]) + ) + return tool_calls + else: + return [] diff --git a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py index 8f8814b3445..a163702e544 100644 --- a/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py +++ b/libs/partners/anthropic/tests/unit_tests/test_output_parsers.py @@ -70,3 +70,19 @@ def test_tools_output_parser_pydantic() -> None: expected = [_Foo1(bar=0), _Foo2(baz="a")] actual = output_parser.parse_result(_RESULT) assert expected == actual + + +def test_tools_output_parser_empty_content() -> None: + class ChartType(BaseModel): + chart_type: Literal["pie", "line", "bar"] + + output_parser = ToolsOutputParser( + first_tool_only=True, pydantic_schemas=[ChartType] + ) + message = AIMessage( + "", + tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}], + ) + actual = output_parser.invoke(message) + expected = ChartType(chart_type="pie") + assert expected == actual