mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-13 08:27:03 +00:00
anthropic[patch]: tool output parser fix (#23647)
This commit is contained in:
parent
46cbf0e4aa
commit
7a6c06cadd
@ -1,6 +1,6 @@
|
|||||||
from typing import Any, List, Optional, Type
|
from typing import Any, List, Optional, Type, Union, cast
|
||||||
|
|
||||||
from langchain_core.messages import ToolCall
|
from langchain_core.messages import AIMessage, ToolCall
|
||||||
from langchain_core.output_parsers import BaseGenerationOutputParser
|
from langchain_core.output_parsers import BaseGenerationOutputParser
|
||||||
from langchain_core.outputs import ChatGeneration, Generation
|
from langchain_core.outputs import ChatGeneration, Generation
|
||||||
from langchain_core.pydantic_v1 import BaseModel
|
from langchain_core.pydantic_v1 import BaseModel
|
||||||
@ -31,19 +31,18 @@ class ToolsOutputParser(BaseGenerationOutputParser):
|
|||||||
"""
|
"""
|
||||||
if not result or not isinstance(result[0], ChatGeneration):
|
if not result or not isinstance(result[0], ChatGeneration):
|
||||||
return None if self.first_tool_only else []
|
return None if self.first_tool_only else []
|
||||||
message = result[0].message
|
message = cast(AIMessage, result[0].message)
|
||||||
if isinstance(message.content, str):
|
tool_calls: List = [
|
||||||
tool_calls: List = []
|
dict(tc) for tc in _extract_tool_calls_from_message(message)
|
||||||
else:
|
]
|
||||||
content: List = message.content
|
if isinstance(message.content, list):
|
||||||
_tool_calls = [dict(tc) for tc in extract_tool_calls(content)]
|
|
||||||
# Map tool call id to index
|
# Map tool call id to index
|
||||||
id_to_index = {
|
id_to_index = {
|
||||||
block["id"]: i
|
block["id"]: i
|
||||||
for i, block in enumerate(content)
|
for i, block in enumerate(message.content)
|
||||||
if block["type"] == "tool_use"
|
if isinstance(block, dict) and block["type"] == "tool_use"
|
||||||
}
|
}
|
||||||
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in _tool_calls]
|
tool_calls = [{**tc, "index": id_to_index[tc["id"]]} for tc in tool_calls]
|
||||||
if self.pydantic_schemas:
|
if self.pydantic_schemas:
|
||||||
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
|
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
|
||||||
elif self.args_only:
|
elif self.args_only:
|
||||||
@ -63,13 +62,25 @@ class ToolsOutputParser(BaseGenerationOutputParser):
|
|||||||
return cls_(**tool_call["args"])
|
return cls_(**tool_call["args"])
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_calls(content: List[dict]) -> List[ToolCall]:
|
def _extract_tool_calls_from_message(message: AIMessage) -> List[ToolCall]:
|
||||||
"""Extract tool calls from a list of content blocks."""
|
"""Extract tool calls from a list of content blocks."""
|
||||||
tool_calls = []
|
if message.tool_calls:
|
||||||
for block in content:
|
return message.tool_calls
|
||||||
if block["type"] != "tool_use":
|
return extract_tool_calls(message.content)
|
||||||
continue
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(name=block["name"], args=block["input"], id=block["id"])
|
def extract_tool_calls(content: Union[str, List[Union[str, dict]]]) -> List[ToolCall]:
|
||||||
)
|
"""Extract tool calls from a list of content blocks."""
|
||||||
return tool_calls
|
if isinstance(content, list):
|
||||||
|
tool_calls = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
continue
|
||||||
|
if block["type"] != "tool_use":
|
||||||
|
continue
|
||||||
|
tool_calls.append(
|
||||||
|
ToolCall(name=block["name"], args=block["input"], id=block["id"])
|
||||||
|
)
|
||||||
|
return tool_calls
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
@ -70,3 +70,19 @@ def test_tools_output_parser_pydantic() -> None:
|
|||||||
expected = [_Foo1(bar=0), _Foo2(baz="a")]
|
expected = [_Foo1(bar=0), _Foo2(baz="a")]
|
||||||
actual = output_parser.parse_result(_RESULT)
|
actual = output_parser.parse_result(_RESULT)
|
||||||
assert expected == actual
|
assert expected == actual
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_output_parser_empty_content() -> None:
|
||||||
|
class ChartType(BaseModel):
|
||||||
|
chart_type: Literal["pie", "line", "bar"]
|
||||||
|
|
||||||
|
output_parser = ToolsOutputParser(
|
||||||
|
first_tool_only=True, pydantic_schemas=[ChartType]
|
||||||
|
)
|
||||||
|
message = AIMessage(
|
||||||
|
"",
|
||||||
|
tool_calls=[{"name": "ChartType", "args": {"chart_type": "pie"}, "id": "foo"}],
|
||||||
|
)
|
||||||
|
actual = output_parser.invoke(message)
|
||||||
|
expected = ChartType(chart_type="pie")
|
||||||
|
assert expected == actual
|
||||||
|
Loading…
Reference in New Issue
Block a user