anthropic[patch]: tool output parser fix (#23647)

This commit is contained in:
Bagatur 2024-07-02 16:33:22 -04:00 committed by GitHub
parent 46cbf0e4aa
commit 7a6c06cadd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 47 additions and 20 deletions

View File

@ -1,6 +1,6 @@
from typing import Any, List, Optional, Type from typing import Any, List, Optional, Type, Union, cast
from langchain_core.messages import ToolCall from langchain_core.messages import AIMessage, ToolCall
from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
@ -31,19 +31,18 @@ class ToolsOutputParser(BaseGenerationOutputParser):
""" """
if not result or not isinstance(result[0], ChatGeneration): if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else [] return None if self.first_tool_only else []
message = result[0].message message = cast(AIMessage, result[0].message)
if isinstance(message.content, str): tool_calls: List = [
tool_calls: List = [] dict(tc) for tc in _extract_tool_calls_from_message(message)
else: ]
content: List = message.content if isinstance(message.content, list):
_tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
# Map tool call id to index # Map tool call id to index
id_to_index = { id_to_index = {
block["id"]: i block["id"]: i
for i, block in enumerate(content) for i, block in enumerate(message.content)
if block["type"] == "tool_use" if isinstance(block, dict) and block["type"] == "tool_use"
} }
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls] tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in tool_calls]
if self.pydantic_schemas: if self.pydantic_schemas:
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls] tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
elif self.args_only: elif self.args_only:
@ -63,13 +62,25 @@ class ToolsOutputParser(BaseGenerationOutputParser):
return cls_(**tool_call["args"]) return cls_(**tool_call["args"])
def extract_tool_calls(content: List[dict]) -> List[ToolCall]: def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
"""Extract tool calls from a list of content blocks.""" """Extract tool calls from a list of content blocks."""
tool_calls = [] if message.tool_calls:
for block in content: return message.tool_calls
if block["type"] != "tool_use": return extract_tool_calls(message.content)
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"]) def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
) """Extract tool calls from a list of content blocks."""
return tool_calls if isinstance(content, list):
tool_calls = []
for block in content:
if isinstance(block, str):
continue
if block["type"] != "tool_use":
continue
tool_calls.append(
ToolCall(name=block["name"], args=block["input"], id=block["id"])
)
return tool_calls
else:
return []

View File

@ -70,3 +70,19 @@ def test_tools_output_parser_pydantic() -> None:
expected = [_Foo1(bar=0), _Foo2(baz="a")] expected = [_Foo1(bar=0), _Foo2(baz="a")]
actual = output_parser.parse_result(_RESULT) actual = output_parser.parse_result(_RESULT)
assert expected == actual assert expected == actual
def test_tools_output_parser_empty_content() -> None:
class ChartType(BaseModel):
chart_type: Literal["pie", "line", "bar"]
output_parser = ToolsOutputParser(
first_tool_only=True, pydantic_schemas=[ChartType]
)
message = AIMessage(
"",
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}],
)
actual = output_parser.invoke(message)
expected = ChartType(chart_type="pie")
assert expected == actual