anthropic[minor]: tool use (#20016)

This commit is contained in:
Bagatur
2024-04-04 13:22:48 -07:00
committed by GitHub
parent 3aacd11846
commit 209de0a561
13 changed files with 1021 additions and 196 deletions

View File

@@ -13,6 +13,9 @@ class ToolMessage(BaseMessage):
tool_call_id: str
"""Tool call that this message is responding to."""
# TODO: Add is_error param?
# is_error: bool = False
# """Whether the tool errored."""
type: Literal["tool"] = "tool"

View File

@@ -1,13 +1,31 @@
import os
import re
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import warnings
from operator import itemgetter
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
cast,
)
import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
@@ -17,14 +35,26 @@ from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import (
Runnable,
RunnableMap,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
get_pydantic_field_names,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_anthropic.output_parsers import ToolsOutputParser
_message_type_lookups = {"human": "user", "ai": "assistant"}
@@ -56,6 +86,41 @@ def _format_image(image_url: str) -> Dict:
}
def _merge_messages(
messages: List[BaseMessage],
) -> List[Union[SystemMessage, AIMessage, HumanMessage]]:
"""Merge runs of human/tool messages into single human messages with content blocks.""" # noqa: E501
merged: list = []
for curr in messages:
if isinstance(curr, ToolMessage):
if isinstance(curr.content, str):
curr = HumanMessage(
[
{
"type": "tool_result",
"content": curr.content,
"tool_use_id": curr.tool_call_id,
}
]
)
else:
curr = HumanMessage(curr.content)
last = merged[-1] if merged else None
if isinstance(last, HumanMessage) and isinstance(curr, HumanMessage):
if isinstance(last.content, str):
new_content: List = [{"type": "text", "text": last.content}]
else:
new_content = last.content
if isinstance(curr.content, str):
new_content.append({"type": "text", "text": curr.content})
else:
new_content.extend(curr.content)
last.content = new_content
else:
merged.append(curr)
return merged
def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[Dict]]:
"""Format messages for anthropic."""
@@ -70,7 +135,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
"""
system: Optional[str] = None
formatted_messages: List[Dict] = []
for i, message in enumerate(messages):
merged_messages = _merge_messages(messages)
for i, message in enumerate(merged_messages):
if message.type == "system":
if i != 0:
raise ValueError("System message must be at beginning of message list.")
@@ -104,7 +171,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
elif isinstance(item, dict):
if "type" not in item:
raise ValueError("Dict content item must have a type key")
if item["type"] == "image_url":
elif item["type"] == "image_url":
# convert format
source = _format_image(item["image_url"]["url"])
content.append(
@@ -113,6 +180,9 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
"source": source,
}
)
elif item["type"] == "tool_use":
item.pop("text", None)
content.append(item)
else:
content.append(item)
else:
@@ -175,6 +245,9 @@ class ChatAnthropic(BaseChatModel):
anthropic_api_key: Optional[SecretStr] = None
default_headers: Optional[Mapping[str, str]] = None
"""Headers to pass to the Anthropic clients, will be used for every API call."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
streaming: bool = False
@@ -207,9 +280,15 @@ class ChatAnthropic(BaseChatModel):
or "https://api.anthropic.com"
)
values["anthropic_api_url"] = api_url
values["_client"] = anthropic.Client(api_key=api_key, base_url=api_url)
values["_client"] = anthropic.Client(
api_key=api_key,
base_url=api_url,
default_headers=values.get("default_headers"),
)
values["_async_client"] = anthropic.AsyncClient(
api_key=api_key, base_url=api_url
api_key=api_key,
base_url=api_url,
default_headers=values.get("default_headers"),
)
return values
@@ -232,6 +311,7 @@ class ChatAnthropic(BaseChatModel):
"stop_sequences": stop,
"system": system,
**self.model_kwargs,
**kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}
@@ -245,6 +325,13 @@ class ChatAnthropic(BaseChatModel):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = self._generate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
return
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
@@ -260,6 +347,13 @@ class ChatAnthropic(BaseChatModel):
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._format_params(messages=messages, stop=stop, **kwargs)
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn("stream: Tool use is not yet supported in streaming mode.")
result = await self._agenerate(
messages, stop=stop, run_manager=run_manager, **kwargs
)
yield cast(ChatGenerationChunk, result.generations[0])
return
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
@@ -273,8 +367,12 @@ class ChatAnthropic(BaseChatModel):
llm_output = {
k: v for k, v in data_dict.items() if k not in ("content", "role", "type")
}
if len(content) == 1 and content[0]["type"] == "text":
msg = AIMessage(content=content[0]["text"])
else:
msg = AIMessage(content=content)
return ChatResult(
generations=[ChatGeneration(message=AIMessage(content=content[0]["text"]))],
generations=[ChatGeneration(message=msg)],
llm_output=llm_output,
)
@@ -285,12 +383,17 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
data = self._client.messages.create(**params)
return self._format_output(data, **kwargs)
@@ -301,15 +404,91 @@ class ChatAnthropic(BaseChatModel):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
params = self._format_params(messages=messages, stop=stop, **kwargs)
if self.streaming:
if "extra_body" in params and params["extra_body"].get("tools"):
warnings.warn(
"stream: Tool use is not yet supported in streaming mode."
)
else:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)
data = await self._async_client.messages.create(**params)
return self._format_output(data, **kwargs)
@beta()
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
Args:
tools: A list of tool definitions to bind to this chat model.
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
models, callables, and BaseTools will be automatically converted to
their schema dictionary representation.
**kwargs: Any additional parameters to bind.
"""
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
extra_body = kwargs.pop("extra_body", {})
extra_body["tools"] = formatted_tools
return self.bind(extra_body=extra_body, **kwargs)
@beta()
def with_structured_output(
self,
schema: Union[Dict, Type[BaseModel]],
*,
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
llm = self.bind_tools([schema])
if isinstance(schema, type) and issubclass(schema, BaseModel):
output_parser = ToolsOutputParser(
first_tool_only=True, pydantic_schemas=[schema]
)
else:
output_parser = ToolsOutputParser(first_tool_only=True, args_only=True)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
class AnthropicTool(TypedDict):
name: str
description: str
input_schema: Dict[str, Any]
def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> AnthropicTool:
# already in Anthropic tool format
if isinstance(tool, dict) and all(
k in tool for k in ("name", "description", "input_schema")
):
return AnthropicTool(tool) # type: ignore
else:
formatted = convert_to_openai_tool(tool)["function"]
return AnthropicTool(
name=formatted["name"],
description=formatted["description"],
input_schema=formatted["parameters"],
)
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):

View File

@@ -1,38 +1,13 @@
import json
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Sequence,
Type,
Union,
cast,
)
from langchain_core._api.beta_decorator import beta
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import (
AIMessage,
BaseMessage,
BaseMessageChunk,
SystemMessage,
)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_core._api import deprecated
from langchain_core.pydantic_v1 import Field
from langchain_anthropic.chat_models import ChatAnthropic
@@ -168,143 +143,16 @@ def _xml_to_tool_calls(elem: Any, tools: List[Dict]) -> List[Dict[str, Any]]:
return [_xml_to_function_call(invoke, tools) for invoke in invokes]
@beta()
@deprecated(
"0.1.5",
removal="0.2.0",
alternative="ChatAnthropic",
message=(
"Tool-calling is now officially supported by the Anthropic API so this "
"workaround is no longer needed."
),
)
class ChatAnthropicTools(ChatAnthropic):
"""Chat model for interacting with Anthropic functions."""
_xmllib: Any = Field(default=None)
@root_validator()
def check_xml_lib(cls, values: Dict[str, Any]) -> Dict[str, Any]:
try:
# do this as an optional dep for temporary nature of this feature
import defusedxml.ElementTree as DET # type: ignore
values["_xmllib"] = DET
except ImportError:
raise ImportError(
"Could not import defusedxml python package. "
"Please install it using `pip install defusedxml`"
)
return values
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tools to the chat model."""
formatted_tools = [convert_to_openai_function(tool) for tool in tools]
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self, schema: Union[Dict, Type[BaseModel]], **kwargs: Any
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
if kwargs:
raise ValueError("kwargs are not supported for with_structured_output")
llm = self.bind_tools([schema])
if isinstance(schema, type) and issubclass(schema, BaseModel):
# schema is pydantic
return llm | PydanticToolsParser(tools=[schema], first_tool_only=True)
else:
# schema is dict
key_name = convert_to_openai_function(schema)["name"]
return llm | JsonOutputKeyToolsParser(
key_name=key_name, first_tool_only=True
)
def _format_params(
self,
*,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Dict:
tools: List[Dict] = kwargs.get("tools", None)
# experimental tools are sent in as part of system prompt, so if
# both are set, turn system prompt into tools + system prompt (tools first)
if tools:
tool_system = get_system_message(tools)
if messages[0].type == "system":
sys_content = messages[0].content
new_sys_content = f"{tool_system}\n\n{sys_content}"
messages = [SystemMessage(content=new_sys_content), *messages[1:]]
else:
messages = [SystemMessage(content=tool_system), *messages]
return super()._format_params(messages=messages, stop=stop, **kwargs)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
# streaming not supported for functions
result = self._generate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
to_yield = result.generations[0]
chunk = ChatGenerationChunk(
message=cast(BaseMessageChunk, to_yield.message),
generation_info=to_yield.generation_info,
)
if run_manager:
run_manager.on_llm_new_token(
cast(str, to_yield.message.content), chunk=chunk
)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
# streaming not supported for functions
result = await self._agenerate(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)
to_yield = result.generations[0]
chunk = ChatGenerationChunk(
message=cast(BaseMessageChunk, to_yield.message),
generation_info=to_yield.generation_info,
)
if run_manager:
await run_manager.on_llm_new_token(
cast(str, to_yield.message.content), chunk=chunk
)
yield chunk
def _format_output(self, data: Any, **kwargs: Any) -> ChatResult:
"""Format the output of the model, parsing xml as a tool call."""
text = data.content[0].text
tools = kwargs.get("tools", None)
additional_kwargs: Dict[str, Any] = {}
if tools:
# parse out the xml from the text
try:
# get everything between <function_calls> and </function_calls>
start = text.find("<function_calls>")
end = text.find("</function_calls>") + len("</function_calls>")
xml_text = text[start:end]
xml = self._xmllib.fromstring(xml_text)
additional_kwargs["tool_calls"] = _xml_to_tool_calls(xml, tools)
text = ""
except Exception:
pass
return ChatResult(
generations=[
ChatGeneration(
message=AIMessage(content=text, additional_kwargs=additional_kwargs)
)
],
llm_output=data,
)

View File

@@ -0,0 +1,66 @@
from typing import Any, List, Optional, Type, TypedDict, cast
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
class _ToolCall(TypedDict):
name: str
args: dict
id: str
index: int
class ToolsOutputParser(BaseGenerationOutputParser):
first_tool_only: bool = False
args_only: bool = False
pydantic_schemas: Optional[List[Type[BaseModel]]] = None
class Config:
extra = "forbid"
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format.
Args:
result: A list of Generations to be parsed. The Generations are assumed
to be different candidate outputs for a single model input.
Returns:
Structured output.
"""
if not result or not isinstance(result[0], ChatGeneration):
return None if self.first_tool_only else []
tool_calls: List = _extract_tool_calls(result[0].message)
if self.pydantic_schemas:
tool_calls = [self._pydantic_parse(tc) for tc in tool_calls]
elif self.args_only:
tool_calls = [tc["args"] for tc in tool_calls]
else:
pass
if self.first_tool_only:
return tool_calls[0] if tool_calls else None
else:
return tool_calls
def _pydantic_parse(self, tool_call: _ToolCall) -> BaseModel:
cls_ = {schema.__name__: schema for schema in self.pydantic_schemas or []}[
tool_call["name"]
]
return cls_(**tool_call["args"])
def _extract_tool_calls(msg: BaseMessage) -> List[_ToolCall]:
if isinstance(msg.content, str):
return []
tool_calls = []
for i, block in enumerate(cast(List[dict], msg.content)):
if block["type"] != "tool_use":
continue
tool_calls.append(
_ToolCall(name=block["name"], args=block["input"], id=block["id"], index=i)
)
return tool_calls

View File

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-anthropic"
version = "0.1.4"
version = "0.1.5"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"

View File

@@ -212,3 +212,47 @@ async def test_astreaming() -> None:
response = await llm.agenerate([[HumanMessage(content="I'm Pickle Rick")]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
def test_tool_use() -> None:
llm = ChatAnthropic(
model="claude-3-opus-20240229",
default_headers={"anthropic-beta": "tools-2024-04-04"},
)
llm_with_tools = llm.bind_tools(
[
{
"name": "get_weather",
"description": "Get weather report for a city",
"input_schema": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
}
]
)
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
def test_with_structured_output() -> None:
llm = ChatAnthropic(
model="claude-3-opus-20240229",
default_headers={"anthropic-beta": "tools-2024-04-04"},
)
structured_llm = llm.with_structured_output(
{
"name": "get_weather",
"description": "Get weather report for a city",
"input_schema": {
"type": "object",
"properties": {"location": {"type": "string"}},
},
}
)
response = structured_llm.invoke("what's the weather in san francisco, ca")
assert isinstance(response, dict)
assert response["location"]

View File

@@ -1,13 +1,17 @@
"""Test chat model integration."""
import os
from typing import Any, Callable, Dict, Literal, Type
import pytest
from anthropic.types import ContentBlock, Message, Usage
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import BaseTool
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.chat_models import _merge_messages, convert_to_anthropic_tool
os.environ["ANTHROPIC_API_KEY"] = "foo"
@@ -83,3 +87,175 @@ def test__format_output() -> None:
llm = ChatAnthropic(model="test", anthropic_api_key="test")
actual = llm._format_output(anthropic_msg)
assert expected == actual
def test__merge_messages() -> None:
messages = [
SystemMessage("foo"),
HumanMessage("bar"),
AIMessage(
[
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "b"},
"type": "tool_use",
"id": "1",
"text": None,
"name": "buz",
},
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "c"},
"type": "tool_use",
"id": "2",
"text": None,
"name": "blah",
},
]
),
ToolMessage("buz output", tool_call_id="1"),
ToolMessage("blah output", tool_call_id="2"),
HumanMessage("next thing"),
]
expected = [
SystemMessage("foo"),
HumanMessage("bar"),
AIMessage(
[
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "b"},
"type": "tool_use",
"id": "1",
"text": None,
"name": "buz",
},
{"text": "baz", "type": "text"},
{
"tool_input": {"a": "c"},
"type": "tool_use",
"id": "2",
"text": None,
"name": "blah",
},
]
),
HumanMessage(
[
{"type": "tool_result", "content": "buz output", "tool_use_id": "1"},
{"type": "tool_result", "content": "blah output", "tool_use_id": "2"},
{"type": "text", "text": "next thing"},
]
),
]
actual = _merge_messages(messages)
assert expected == actual
@pytest.fixture()
def pydantic() -> Type[BaseModel]:
class dummy_function(BaseModel):
"""dummy function"""
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
return dummy_function
@pytest.fixture()
def function() -> Callable:
def dummy_function(arg1: int, arg2: Literal["bar", "baz"]) -> None:
"""dummy function
Args:
arg1: foo
arg2: one of 'bar', 'baz'
"""
pass
return dummy_function
@pytest.fixture()
def dummy_tool() -> BaseTool:
class Schema(BaseModel):
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
class DummyFunction(BaseTool):
args_schema: Type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"
def _run(self, *args: Any, **kwargs: Any) -> Any:
pass
return DummyFunction()
@pytest.fixture()
def json_schema() -> Dict:
return {
"title": "dummy_function",
"description": "dummy function",
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
}
@pytest.fixture()
def openai_function() -> Dict:
return {
"name": "dummy_function",
"description": "dummy function",
"parameters": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
},
}
def test_convert_to_anthropic_tool(
pydantic: Type[BaseModel],
function: Callable,
dummy_tool: BaseTool,
json_schema: Dict,
openai_function: Dict,
) -> None:
expected = {
"name": "dummy_function",
"description": "dummy function",
"input_schema": {
"type": "object",
"properties": {
"arg1": {"description": "foo", "type": "integer"},
"arg2": {
"description": "one of 'bar', 'baz'",
"enum": ["bar", "baz"],
"type": "string",
},
},
"required": ["arg1", "arg2"],
},
}
for fn in (pydantic, function, dummy_tool, json_schema, expected, openai_function):
actual = convert_to_anthropic_tool(fn) # type: ignore
assert actual == expected

View File

@@ -0,0 +1,72 @@
from typing import Any, List, Literal
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel
from langchain_anthropic.output_parsers import ToolsOutputParser
_CONTENT: List = [
{
"type": "text",
"text": "thought",
},
{"type": "tool_use", "input": {"bar": 0}, "id": "1", "name": "_Foo1"},
{
"type": "text",
"text": "thought",
},
{"type": "tool_use", "input": {"baz": "a"}, "id": "2", "name": "_Foo2"},
]
_RESULT: List = [ChatGeneration(message=AIMessage(_CONTENT))]
class _Foo1(BaseModel):
bar: int
class _Foo2(BaseModel):
baz: Literal["a", "b"]
def test_tools_output_parser() -> None:
output_parser = ToolsOutputParser()
expected = [
{"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1},
{"name": "_Foo2", "args": {"baz": "a"}, "id": "2", "index": 3},
]
actual = output_parser.parse_result(_RESULT)
assert expected == actual
def test_tools_output_parser_args_only() -> None:
output_parser = ToolsOutputParser(args_only=True)
expected = [
{"bar": 0},
{"baz": "a"},
]
actual = output_parser.parse_result(_RESULT)
assert expected == actual
expected = []
actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
assert expected == actual
def test_tools_output_parser_first_tool_only() -> None:
output_parser = ToolsOutputParser(first_tool_only=True)
expected: Any = {"name": "_Foo1", "args": {"bar": 0}, "id": "1", "index": 1}
actual = output_parser.parse_result(_RESULT)
assert expected == actual
expected = None
actual = output_parser.parse_result([ChatGeneration(message=AIMessage(""))])
assert expected == actual
def test_tools_output_parser_pydantic() -> None:
output_parser = ToolsOutputParser(pydantic_schemas=[_Foo1, _Foo2])
expected = [_Foo1(bar=0), _Foo2(baz="a")]
actual = output_parser.parse_result(_RESULT)
assert expected == actual