add formatted_tools field

This commit is contained in:
Chester Curme 2024-11-06 09:54:57 -05:00
parent a72e9d14f0
commit 8662fd8c7d
2 changed files with 81 additions and 67 deletions

View File

@ -84,6 +84,15 @@ _message_type_lookups = {
} }
class AnthropicTool(TypedDict):
"""Anthropic tool definition."""
name: str
description: str
input_schema: Dict[str, Any]
cache_control: NotRequired[Dict[str, str]]
def _format_image(image_url: str) -> Dict: def _format_image(image_url: str) -> Dict:
""" """
Formats an image of format data:image/jpeg;base64,{b64_string} Formats an image of format data:image/jpeg;base64,{b64_string}
@ -604,6 +613,9 @@ class ChatAnthropic(BaseChatModel):
message chunks will be generated during the stream including usage metadata. message chunks will be generated during the stream including usage metadata.
""" """
formatted_tools: List[AnthropicTool] = Field(default_factory=list)
"""Tools in Anthropic format to be passed to model invocations."""
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
"""Return type of chat model.""" """Return type of chat model."""
@ -690,6 +702,8 @@ class ChatAnthropic(BaseChatModel):
) -> Dict: ) -> Dict:
messages = self._convert_input(input_).to_messages() messages = self._convert_input(input_).to_messages()
system, formatted_messages = _format_messages(messages) system, formatted_messages = _format_messages(messages)
if self.formatted_tools and "tools" not in kwargs:
kwargs["tools"] = self.formatted_tools # type: ignore[assignment]
payload = { payload = {
"model": self.model, "model": self.model,
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
@ -955,6 +969,7 @@ class ChatAnthropic(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
self.formatted_tools = formatted_tools
if not tool_choice: if not tool_choice:
pass pass
elif isinstance(tool_choice, dict): elif isinstance(tool_choice, dict):
@ -1120,43 +1135,24 @@ class ChatAnthropic(BaseChatModel):
.. versionchanged:: 0.2.5 .. versionchanged:: 0.2.5
Uses Anthropic's token counting API to count tokens in messages. See: Uses Anthropic's token counting API to count tokens in messages. See:
https://docs.anthropic.com/en/api/messages-count-tokens https://docs.anthropic.com/en/docs/build-with-claude/token-counting
""" """
if any(
isinstance(tool, ToolMessage)
or (isinstance(tool, AIMessage) and tool.tool_calls)
for tool in messages
):
raise NotImplementedError(
"get_num_tokens_from_messages does not yet support counting tokens "
"in tool calls."
)
formatted_system, formatted_messages = _format_messages(messages) formatted_system, formatted_messages = _format_messages(messages)
kwargs: Dict[str, Any] = {}
if isinstance(formatted_system, str): if isinstance(formatted_system, str):
response = self._client.beta.messages.count_tokens( kwargs["system"] = formatted_system
betas=["token-counting-2024-11-01"], if self.formatted_tools:
model=self.model, kwargs["tools"] = self.formatted_tools
system=formatted_system,
messages=formatted_messages, # type: ignore[arg-type] response = self._client.beta.messages.count_tokens(
) betas=["token-counting-2024-11-01"],
else: model=self.model,
response = self._client.beta.messages.count_tokens( messages=formatted_messages, # type: ignore[arg-type]
betas=["token-counting-2024-11-01"], **kwargs,
model=self.model, )
messages=formatted_messages, # type: ignore[arg-type]
)
return response.input_tokens return response.input_tokens
class AnthropicTool(TypedDict):
"""Anthropic tool definition."""
name: str
description: str
input_schema: Dict[str, Any]
cache_control: NotRequired[Dict[str, str]]
def convert_to_anthropic_tool( def convert_to_anthropic_tool(
tool: Union[Dict[str, Any], Type, Callable, BaseTool], tool: Union[Dict[str, Any], Type, Callable, BaseTool],
) -> AnthropicTool: ) -> AnthropicTool:

View File

@ -20,6 +20,7 @@ from langchain_core.tools import tool
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.chat_models import convert_to_anthropic_tool
from tests.unit_tests._utils import FakeCallbackHandler from tests.unit_tests._utils import FakeCallbackHandler
MODEL_NAME = "claude-3-sonnet-20240229" MODEL_NAME = "claude-3-sonnet-20240229"
@ -368,18 +369,15 @@ async def test_astreaming() -> None:
def test_tool_use() -> None: def test_tool_use() -> None:
llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
llm_with_tools = llm.bind_tools( tool_schema = {
[ "name": "get_weather",
{ "description": "Get weather report for a city",
"name": "get_weather", "input_schema": {
"description": "Get weather report for a city", "type": "object",
"input_schema": { "properties": {"location": {"type": "string"}},
"type": "object", },
"properties": {"location": {"type": "string"}}, }
}, llm_with_tools = llm.bind_tools([tool_schema])
}
]
)
response = llm_with_tools.invoke("what's the weather in san francisco, ca") response = llm_with_tools.invoke("what's the weather in san francisco, ca")
assert isinstance(response, AIMessage) assert isinstance(response, AIMessage)
assert isinstance(response.content, list) assert isinstance(response.content, list)
@ -441,6 +439,31 @@ def test_tool_use() -> None:
gathered = gathered + chunk # type: ignore gathered = gathered + chunk # type: ignore
assert len(chunks) > 1 assert len(chunks) > 1
# Test via init
llm_with_tools = ChatAnthropic(model=MODEL_NAME, formatted_tools=[tool_schema]) # type: ignore
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert isinstance(response.tool_calls, list)
assert len(response.tool_calls) == 1
# Test tool conversion
@tool
def get_weather(location: str) -> str:
"""Get weather report for a city"""
return "Sunny"
formatted_tool = convert_to_anthropic_tool(get_weather)
llm_with_tools = ChatAnthropic(
model=MODEL_NAME, # type: ignore[call-arg]
formatted_tools=[formatted_tool],
)
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
assert isinstance(response, AIMessage)
assert isinstance(response.content, list)
assert isinstance(response.tool_calls, list)
assert len(response.tool_calls) == 1
def test_anthropic_with_empty_text_block() -> None: def test_anthropic_with_empty_text_block() -> None:
"""Anthropic SDK can return an empty text block.""" """Anthropic SDK can return an empty text block."""
@ -518,30 +541,25 @@ def test_get_num_tokens_from_messages() -> None:
num_tokens = llm.get_num_tokens_from_messages(messages) num_tokens = llm.get_num_tokens_from_messages(messages)
assert num_tokens > 0 assert num_tokens > 0
# Test tool use (not yet supported) # Test tool use
messages = [ @tool
AIMessage( def get_weather(location: str) -> str:
content=[ """Get weather report for a city"""
{"text": "Let's see.", "type": "text"}, return "Sunny"
{
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk", ## via init
"input": {"location": "SF"}, formatted_tool = convert_to_anthropic_tool(get_weather)
"name": "get_weather", llm = ChatAnthropic(
"type": "tool_use", model="claude-3-5-haiku-20241022", # type: ignore[call-arg]
}, formatted_tools=[formatted_tool],
], )
tool_calls=[ num_tokens = llm.get_num_tokens_from_messages(messages)
{ assert num_tokens > 0
"name": "get_weather",
"args": {"location": "SF"}, ## via bind_tools
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk", llm_with_tools = llm.bind_tools([get_weather])
"type": "tool_call", num_tokens = llm_with_tools.get_num_tokens_from_messages(messages) # type: ignore[attr-defined]
}, assert num_tokens > 0
],
)
]
with pytest.raises(NotImplementedError):
num_tokens = llm.get_num_tokens_from_messages(messages)
class GetWeather(BaseModel): class GetWeather(BaseModel):