mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
pass tools into get_num_tokens
This commit is contained in:
parent
077199c5de
commit
668e4c68ec
@ -15,6 +15,7 @@ from typing import (
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypedDict,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@ -71,7 +72,7 @@ from pydantic import (
|
||||
SecretStr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import NotRequired, Self, TypedDict
|
||||
from typing_extensions import NotRequired, Self
|
||||
|
||||
from langchain_anthropic.output_parsers import extract_tool_calls
|
||||
|
||||
@ -83,15 +84,6 @@ _message_type_lookups = {
|
||||
}
|
||||
|
||||
|
||||
class AnthropicTool(TypedDict):
|
||||
"""Anthropic tool definition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
cache_control: NotRequired[Dict[str, str]]
|
||||
|
||||
|
||||
def _format_image(image_url: str) -> Dict:
|
||||
"""
|
||||
Formats an image of format data:image/jpeg;base64,{b64_string}
|
||||
@ -612,9 +604,6 @@ class ChatAnthropic(BaseChatModel):
|
||||
message chunks will be generated during the stream including usage metadata.
|
||||
"""
|
||||
|
||||
formatted_tools: List[AnthropicTool] = Field(default_factory=list)
|
||||
"""Tools in Anthropic format to be passed to model invocations."""
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of chat model."""
|
||||
@ -701,8 +690,6 @@ class ChatAnthropic(BaseChatModel):
|
||||
) -> Dict:
|
||||
messages = self._convert_input(input_).to_messages()
|
||||
system, formatted_messages = _format_messages(messages)
|
||||
if self.formatted_tools and "tools" not in kwargs:
|
||||
kwargs["tools"] = self.formatted_tools # type: ignore[assignment]
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"max_tokens": self.max_tokens,
|
||||
@ -968,7 +955,6 @@ class ChatAnthropic(BaseChatModel):
|
||||
|
||||
""" # noqa: E501
|
||||
formatted_tools = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||
self.formatted_tools = formatted_tools
|
||||
if not tool_choice:
|
||||
pass
|
||||
elif isinstance(tool_choice, dict):
|
||||
@ -1128,7 +1114,13 @@ class ChatAnthropic(BaseChatModel):
|
||||
return llm | output_parser
|
||||
|
||||
@beta()
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
def get_num_tokens_from_messages(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
tools: Optional[
|
||||
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
|
||||
] = None,
|
||||
) -> int:
|
||||
"""Count tokens in a sequence of input messages.
|
||||
|
||||
.. versionchanged:: 0.2.5
|
||||
@ -1140,8 +1132,8 @@ class ChatAnthropic(BaseChatModel):
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if isinstance(formatted_system, str):
|
||||
kwargs["system"] = formatted_system
|
||||
if self.formatted_tools:
|
||||
kwargs["tools"] = self.formatted_tools
|
||||
if tools:
|
||||
kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
|
||||
|
||||
response = self._client.beta.messages.count_tokens(
|
||||
betas=["token-counting-2024-11-01"],
|
||||
@ -1152,6 +1144,15 @@ class ChatAnthropic(BaseChatModel):
|
||||
return response.input_tokens
|
||||
|
||||
|
||||
class AnthropicTool(TypedDict):
|
||||
"""Anthropic tool definition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
input_schema: Dict[str, Any]
|
||||
cache_control: NotRequired[Dict[str, str]]
|
||||
|
||||
|
||||
def convert_to_anthropic_tool(
|
||||
tool: Union[Dict[str, Any], Type, Callable, BaseTool],
|
||||
) -> AnthropicTool:
|
||||
|
@ -20,7 +20,6 @@ from langchain_core.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
|
||||
from langchain_anthropic.chat_models import convert_to_anthropic_tool
|
||||
from tests.unit_tests._utils import FakeCallbackHandler
|
||||
|
||||
MODEL_NAME = "claude-3-sonnet-20240229"
|
||||
@ -369,15 +368,18 @@ async def test_astreaming() -> None:
|
||||
|
||||
def test_tool_use() -> None:
|
||||
llm = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
|
||||
tool_schema = {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather report for a city",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
llm_with_tools = llm.bind_tools([tool_schema])
|
||||
llm_with_tools = llm.bind_tools(
|
||||
[
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather report for a city",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {"location": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, list)
|
||||
@ -439,31 +441,6 @@ def test_tool_use() -> None:
|
||||
gathered = gathered + chunk # type: ignore
|
||||
assert len(chunks) > 1
|
||||
|
||||
# Test via init
|
||||
llm_with_tools = ChatAnthropic(model=MODEL_NAME, formatted_tools=[tool_schema]) # type: ignore
|
||||
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, list)
|
||||
assert isinstance(response.tool_calls, list)
|
||||
assert len(response.tool_calls) == 1
|
||||
|
||||
# Test tool conversion
|
||||
@tool
|
||||
def get_weather(location: str) -> str:
|
||||
"""Get weather report for a city"""
|
||||
return "Sunny"
|
||||
|
||||
formatted_tool = convert_to_anthropic_tool(get_weather)
|
||||
llm_with_tools = ChatAnthropic(
|
||||
model=MODEL_NAME, # type: ignore[call-arg]
|
||||
formatted_tools=[formatted_tool],
|
||||
)
|
||||
response = llm_with_tools.invoke("what's the weather in san francisco, ca")
|
||||
assert isinstance(response, AIMessage)
|
||||
assert isinstance(response.content, list)
|
||||
assert isinstance(response.tool_calls, list)
|
||||
assert len(response.tool_calls) == 1
|
||||
|
||||
|
||||
def test_anthropic_with_empty_text_block() -> None:
|
||||
"""Anthropic SDK can return an empty text block."""
|
||||
@ -570,19 +547,7 @@ def test_get_num_tokens_from_messages() -> None:
|
||||
),
|
||||
ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"),
|
||||
]
|
||||
|
||||
## via init
|
||||
formatted_tool = convert_to_anthropic_tool(get_weather)
|
||||
llm = ChatAnthropic(
|
||||
model="claude-3-5-haiku-20241022", # type: ignore[call-arg]
|
||||
formatted_tools=[formatted_tool],
|
||||
)
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages)
|
||||
assert num_tokens > 0
|
||||
|
||||
## via bind_tools
|
||||
llm_with_tools = llm.bind_tools([get_weather])
|
||||
num_tokens = llm_with_tools.get_num_tokens_from_messages(messages) # type: ignore[attr-defined]
|
||||
num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
|
||||
assert num_tokens > 0
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user