diff --git a/libs/partners/anthropic/langchain_anthropic/chat_models.py b/libs/partners/anthropic/langchain_anthropic/chat_models.py index 0784ff2bdeb..9a1885a3680 100644 --- a/libs/partners/anthropic/langchain_anthropic/chat_models.py +++ b/libs/partners/anthropic/langchain_anthropic/chat_models.py @@ -84,6 +84,15 @@ _message_type_lookups = { } +class AnthropicTool(TypedDict): + """Anthropic tool definition.""" + + name: str + description: str + input_schema: Dict[str, Any] + cache_control: NotRequired[Dict[str, str]] + + def _format_image(image_url: str) -> Dict: """ Formats an image of format data:image/jpeg;base64,{b64_string} @@ -604,6 +613,9 @@ class ChatAnthropic(BaseChatModel): message chunks will be generated during the stream including usage metadata. """ + formatted_tools: List[AnthropicTool] = Field(default_factory=list) + """Tools in Anthropic format to be passed to model invocations.""" + @property def _llm_type(self) -> str: """Return type of chat model.""" @@ -690,6 +702,8 @@ class ChatAnthropic(BaseChatModel): ) -> Dict: messages = self._convert_input(input_).to_messages() system, formatted_messages = _format_messages(messages) + if self.formatted_tools and "tools" not in kwargs: + kwargs["tools"] = self.formatted_tools # type: ignore[assignment] payload = { "model": self.model, "max_tokens": self.max_tokens, @@ -955,6 +969,7 @@ class ChatAnthropic(BaseChatModel): """ # noqa: E501 formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools] + self.formatted_tools = formatted_tools if not tool_choice: pass elif isinstance(tool_choice, dict): @@ -1120,43 +1135,24 @@ class ChatAnthropic(BaseChatModel): .. versionchanged:: 0.2.5 Uses Anthropic's token counting API to count tokens in messages. See: - https://docs.anthropic.com/en/api/messages-count-tokens + https://docs.anthropic.com/en/docs/build-with-claude/token-counting """ - if any( - isinstance(tool, ToolMessage) - or (isinstance(tool, AIMessage) and tool.tool_calls) - for tool in messages - ): - raise NotImplementedError( - "get_num_tokens_from_messages does not yet support counting tokens " - "in tool calls." - ) formatted_system, formatted_messages = _format_messages(messages) + kwargs: Dict[str, Any] = {} if isinstance(formatted_system, str): - response = self._client.beta.messages.count_tokens( - betas=["token-counting-2024-11-01"], - model=self.model, - system=formatted_system, - messages=formatted_messages, # type: ignore[arg-type] - ) - else: - response = self._client.beta.messages.count_tokens( - betas=["token-counting-2024-11-01"], - model=self.model, - messages=formatted_messages, # type: ignore[arg-type] - ) + kwargs["system"] = formatted_system + if self.formatted_tools: + kwargs["tools"] = self.formatted_tools + + response = self._client.beta.messages.count_tokens( + betas=["token-counting-2024-11-01"], + model=self.model, + messages=formatted_messages, # type: ignore[arg-type] + **kwargs, + ) return response.input_tokens -class AnthropicTool(TypedDict): - """Anthropic tool definition.""" - - name: str - description: str - input_schema: Dict[str, Any] - cache_control: NotRequired[Dict[str, str]] - - def convert_to_anthropic_tool( tool: Union[Dict[str, Any], Type, Callable, BaseTool], ) -> AnthropicTool: diff --git a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py index 47736880b25..73ae6d9d0c6 100644 --- a/libs/partners/anthropic/tests/integration_tests/test_chat_models.py +++ b/libs/partners/anthropic/tests/integration_tests/test_chat_models.py @@ -20,6 +20,7 @@ from langchain_core.tools import tool from pydantic import BaseModel, Field from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages +from langchain_anthropic.chat_models import convert_to_anthropic_tool from tests.unit_tests._utils import FakeCallbackHandler MODEL_NAME = "claude-3-sonnet-20240229" @@ -368,18 +369,15 @@ async def test_astreaming() -> None: def test_tool_use() -> None: llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg] - llm_with_tools = llm.bind_tools( - [ - { - "name": "get_weather", - "description": "Get weather report for a city", - "input_schema": { - "type": "object", - "properties": {"location": {"type": "string"}}, - }, - } - ] - ) + tool_schema = { + "name": "get_weather", + "description": "Get weather report for a city", + "input_schema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + }, + } + llm_with_tools = llm.bind_tools([tool_schema]) response = llm_with_tools.invoke("what's the weather in san francisco, ca") assert isinstance(response, AIMessage) assert isinstance(response.content, list) @@ -441,6 +439,31 @@ def test_tool_use() -> None: gathered = gathered + chunk # type: ignore assert len(chunks) > 1 + # Test via init + llm_with_tools = ChatAnthropic(model=MODEL_NAME, formatted_tools=[tool_schema]) # type: ignore + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert isinstance(response.content, list) + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + + # Test tool conversion + @tool + def get_weather(location: str) -> str: + """Get weather report for a city""" + return "Sunny" + + formatted_tool = convert_to_anthropic_tool(get_weather) + llm_with_tools = ChatAnthropic( + model=MODEL_NAME, # type: ignore[call-arg] + formatted_tools=[formatted_tool], + ) + response = llm_with_tools.invoke("what's the weather in san francisco, ca") + assert isinstance(response, AIMessage) + assert isinstance(response.content, list) + assert isinstance(response.tool_calls, list) + assert len(response.tool_calls) == 1 + def test_anthropic_with_empty_text_block() -> None: """Anthropic SDK can return an empty text block.""" @@ -518,30 +541,25 @@ def test_get_num_tokens_from_messages() -> None: num_tokens = llm.get_num_tokens_from_messages(messages) assert num_tokens > 0 - # Test tool use (not yet supported) - messages = [ - AIMessage( - content=[ - {"text": "Let's see.", "type": "text"}, - { - "id": "toolu_01V6d6W32QGGSmQm4BT98EKk", - "input": {"location": "SF"}, - "name": "get_weather", - "type": "tool_use", - }, - ], - tool_calls=[ - { - "name": "get_weather", - "args": {"location": "SF"}, - "id": "toolu_01V6d6W32QGGSmQm4BT98EKk", - "type": "tool_call", - }, - ], - ) - ] - with pytest.raises(NotImplementedError): - num_tokens = llm.get_num_tokens_from_messages(messages) + # Test tool use + @tool + def get_weather(location: str) -> str: + """Get weather report for a city""" + return "Sunny" + + ## via init + formatted_tool = convert_to_anthropic_tool(get_weather) + llm = ChatAnthropic( + model="claude-3-5-haiku-20241022", # type: ignore[call-arg] + formatted_tools=[formatted_tool], + ) + num_tokens = llm.get_num_tokens_from_messages(messages) + assert num_tokens > 0 + + ## via bind_tools + llm_with_tools = llm.bind_tools([get_weather]) + num_tokens = llm_with_tools.get_num_tokens_from_messages(messages) # type: ignore[attr-defined] + assert num_tokens > 0 class GetWeather(BaseModel):