Mirror of https://github.com/hwchase17/langchain.git
anthropic[major]: support python 3.13 (#27916)
Last week Anthropic released version 0.39.0 of its Python SDK, enabling support for Python 3.13. That release also deleted the legacy `client.count_tokens` method, which we currently access during init of the `AnthropicLLM` class. Anthropic has replaced this functionality with the [client.beta.messages.count_tokens() API](https://github.com/anthropics/anthropic-sdk-python/pull/726). To support `anthropic >= 0.39.0` and Python 3.13, we drop the legacy token-counting method here and add support for the new API via `ChatAnthropic.get_num_tokens_from_messages`. To fully support the token-counting API, we also update the signature of `get_num_tokens_from_messages` to accept tools everywhere.

Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
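For reference, a minimal sketch of the replacement call at the SDK level (assumes `anthropic >= 0.39.0` and an `ANTHROPIC_API_KEY` in the environment; the model name is illustrative):

```python
import anthropic

client = anthropic.Anthropic()

# The beta endpoint that replaces the removed client.count_tokens helper.
response = client.beta.messages.count_tokens(
    betas=["token-counting-2024-11-01"],
    model="claude-3-5-sonnet-20241022",  # illustrative model name
    messages=[{"role": "user", "content": "Hello, Claude"}],
)
print(response.input_tokens)
```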
Parent: 759b6ed17a
Commit: 1538ee17f9
.github/scripts/check_diff.py (vendored, 1 line changed):
@@ -37,7 +37,6 @@ IGNORED_PARTNERS = [
 PY_312_MAX_PACKAGES = [
     f"libs/partners/{integration}"
     for integration in [
-        "anthropic",
         "chroma",
         "couchbase",
         "huggingface",

libs/community/langchain_community/chat_models/anyscale.py:

@@ -5,10 +5,22 @@ from __future__ import annotations

 import logging
 import os
 import sys
-from typing import TYPE_CHECKING, Any, Dict, Optional, Set
+import warnings
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Sequence,
+    Set,
+    Type,
+    Union,
+)

 import requests
 from langchain_core.messages import BaseMessage
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from pydantic import Field, SecretStr, model_validator
@@ -197,10 +209,20 @@ class ChatAnyscale(ChatOpenAI):
         encoding = tiktoken_.get_encoding(model)
         return model, encoding

-    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
         """Calculate num tokens with tiktoken package.
         Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
         """
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools."
+            )
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()

libs/community/langchain_community/chat_models/everlyai.py:

@@ -4,9 +4,21 @@ from __future__ import annotations

 import logging
 import sys
-from typing import TYPE_CHECKING, Any, Dict, Optional, Set
+import warnings
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Optional,
+    Sequence,
+    Set,
+    Type,
+    Union,
+)

 from langchain_core.messages import BaseMessage
+from langchain_core.tools import BaseTool
 from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 from pydantic import Field, model_validator
@@ -138,11 +150,21 @@ class ChatEverlyAI(ChatOpenAI):
         encoding = tiktoken_.get_encoding(model)
         return model, encoding

-    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
         """Calculate num tokens with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools."
+            )
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()

libs/community/langchain_community/chat_models/openai.py:

@@ -46,6 +46,7 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import Runnable
+from langchain_core.tools import BaseTool
 from langchain_core.utils import (
     get_from_dict_or_env,
     get_pydantic_field_names,
@@ -644,11 +645,21 @@ class ChatOpenAI(BaseChatModel):
         _, encoding_model = self._get_encoding_model()
         return encoding_model.encode(text)

-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: List[BaseMessage],
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         Official documentation: https://github.com/openai/openai-cookbook/blob/
         main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools."
+            )
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()

libs/core/langchain_core/language_models/base.py:

@@ -1,5 +1,6 @@
 from __future__ import annotations

+import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Mapping, Sequence
 from functools import cache
@@ -364,17 +365,31 @@ class BaseLanguageModel(
         """
         return len(self.get_token_ids(text))

-    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[Sequence] = None,
+    ) -> int:
         """Get the number of tokens in the messages.

         Useful for checking if an input fits in a model's context window.

+        **Note**: the base implementation of get_num_tokens_from_messages ignores
+        tool schemas.
+
         Args:
             messages: The message inputs to tokenize.
+            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+                to be converted to tool schemas.

         Returns:
             The sum of the number of tokens across the messages.
         """
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools.",
+                stacklevel=2,
+            )
         return sum([self.get_num_tokens(get_buffer_string([m])) for m in messages])

     @classmethod

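A toy sketch of the base-class behavior above, assuming a core release with this change applied: any `tools` argument is ignored with a warning, and the count falls back to `get_num_tokens` over each message's buffer string. The subclass and its character-level tokenization are purely illustrative:

```python
import warnings

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import HumanMessage


class CharTokenModel(FakeListChatModel):
    # Illustrative only: one "token" per character, so the base
    # get_num_tokens_from_messages needs no external tokenizer.
    def get_token_ids(self, text: str) -> list[int]:
        return list(range(len(text)))


model = CharTokenModel(responses=["ok"])
messages = [HumanMessage(content="hello")]
print(model.get_num_tokens_from_messages(messages))  # counts message text only

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.get_num_tokens_from_messages(messages, tools=[{"name": "noop"}])
    print(caught[0].message)  # tool schemas are ignored with a warning
```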
libs/core/tests/unit_tests/messages/test_utils.py:

@@ -1,5 +1,8 @@
 import base64
 import json
+import typing
+from collections.abc import Sequence
+from typing import Any, Callable, Optional, Union

 import pytest
@@ -19,6 +22,7 @@ from langchain_core.messages.utils import (
     merge_message_runs,
     trim_messages,
 )
+from langchain_core.tools import BaseTool


 @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:


 class FakeTokenCountingModel(FakeChatModel):
-    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[
+            Sequence[
+                Union[typing.Dict[str, Any], type, Callable, BaseTool]  # noqa: UP006
+            ]
+        ] = None,
+    ) -> int:
         return dummy_token_counter(messages)

libs/langchain/tests/unit_tests/chat_models/test_base.py:

@@ -3,7 +3,6 @@ from unittest import mock

 import pytest
 from langchain_core.language_models import BaseChatModel
-from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableConfig, RunnableSequence
 from pydantic import SecretStr
@@ -180,9 +179,6 @@ def test_configurable_with_default() -> None:
     )

     assert model_with_config.model == "claude-3-sonnet-20240229"  # type: ignore[attr-defined]
-    # Anthropic defaults to using `transformers` for token counting.
-    with pytest.raises(ImportError):
-        model_with_config.get_num_tokens_from_messages([(HumanMessage("foo"))])  # type: ignore[attr-defined]

     assert model_with_config.model_dump() == {  # type: ignore[attr-defined]
         "name": None,

libs/partners/anthropic/langchain_anthropic/chat_models.py:

@@ -21,7 +21,7 @@ from typing import (
 )

 import anthropic
-from langchain_core._api import deprecated
+from langchain_core._api import beta, deprecated
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -1113,6 +1113,41 @@ class ChatAnthropic(BaseChatModel):
         else:
             return llm | output_parser

+    @beta()
+    def get_num_tokens_from_messages(
+        self,
+        messages: List[BaseMessage],
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
+        """Count tokens in a sequence of input messages.
+
+        Args:
+            messages: The message inputs to tokenize.
+            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+                to be converted to tool schemas.
+
+        .. versionchanged:: 0.3.0
+
+            Uses Anthropic's token counting API to count tokens in messages. See:
+            https://docs.anthropic.com/en/docs/build-with-claude/token-counting
+        """
+        formatted_system, formatted_messages = _format_messages(messages)
+        kwargs: Dict[str, Any] = {}
+        if isinstance(formatted_system, str):
+            kwargs["system"] = formatted_system
+        if tools:
+            kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
+
+        response = self._client.beta.messages.count_tokens(
+            betas=["token-counting-2024-11-01"],
+            model=self.model,
+            messages=formatted_messages,  # type: ignore[arg-type]
+            **kwargs,
+        )
+        return response.input_tokens
+

 class AnthropicTool(TypedDict):
     """Anthropic tool definition."""

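A hedged usage sketch of the new (beta) method, mirroring the integration test further below. It calls Anthropic's token counting API, so it needs a valid `ANTHROPIC_API_KEY`; the model name is illustrative:

```python
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool

llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")

# Plain messages.
messages = [
    SystemMessage(content="You are a scientist"),
    HumanMessage(content="Hello, Claude"),
]
print(llm.get_num_tokens_from_messages(messages))

# With tools, the converted tool schemas are sent along and counted too.
@tool
def get_weather(location: str) -> str:
    """Get the current weather in a given location."""
    return "Sunny"

print(llm.get_num_tokens_from_messages(messages, tools=[get_weather]))
```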
libs/partners/anthropic/langchain_anthropic/llms.py:

@@ -109,7 +109,6 @@ class _AnthropicCommon(BaseLanguageModel):
         )
         self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
         self.AI_PROMPT = anthropic.AI_PROMPT
-        self.count_tokens = self.client.count_tokens
         return self

     @property
@@ -375,9 +374,11 @@ class AnthropicLLM(LLM, _AnthropicCommon):

     def get_num_tokens(self, text: str) -> int:
         """Calculate number of tokens."""
-        if not self.count_tokens:
-            raise NameError("Please ensure the anthropic package is loaded")
-        return self.count_tokens(text)
+        raise NotImplementedError(
+            "Anthropic's legacy count_tokens method was removed in anthropic 0.39.0 "
+            "and langchain-anthropic 0.3.0. Please use "
+            "ChatAnthropic.get_num_tokens_from_messages instead."
+        )


 @deprecated(since="0.1.0", removal="0.3.0", alternative="AnthropicLLM")

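For callers that relied on the removed `AnthropicLLM.get_num_tokens`, a sketch of the replacement path (the text is simply wrapped in a message; model name illustrative):

```python
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

# Before: AnthropicLLM(...).get_num_tokens("Hello, Claude")  # now raises NotImplementedError
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")
num_tokens = llm.get_num_tokens_from_messages([HumanMessage(content="Hello, Claude")])
```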
libs/partners/anthropic/poetry.lock (generated, 843 lines changed):

File diff suppressed because it is too large.
libs/partners/anthropic/pyproject.toml:

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

 [tool.poetry]
 name = "langchain-anthropic"
-version = "0.2.4"
+version = "0.3.0"
 description = "An integration package connecting AnthropicMessages and LangChain"
 authors = []
 readme = "README.md"
@@ -20,7 +20,7 @@ disallow_untyped_defs = "True"

 [tool.poetry.dependencies]
 python = ">=3.9,<4.0"
-anthropic = ">=0.30.0,<1"
+anthropic = ">=0.39.0,<1"
 langchain-core = "^0.3.15"
 pydantic = "^2.7.4"

libs/partners/anthropic/tests/integration_tests/test_chat_models.py:

@@ -317,7 +317,7 @@ async def test_anthropic_async_streaming_callback() -> None:
 def test_anthropic_multimodal() -> None:
     """Test that multimodal inputs are handled correctly."""
     chat = ChatAnthropic(model=MODEL_NAME)  # type: ignore[call-arg]
-    messages = [
+    messages: list[BaseMessage] = [
         HumanMessage(
             content=[
                 {
@@ -334,6 +334,8 @@ def test_anthropic_multimodal() -> None:
     response = chat.invoke(messages)
     assert isinstance(response, AIMessage)
     assert isinstance(response.content, str)
+    num_tokens = chat.get_num_tokens_from_messages(messages)
+    assert num_tokens > 0


 def test_streaming() -> None:
@@ -505,6 +507,60 @@ def test_with_structured_output() -> None:
     assert response["location"]


+def test_get_num_tokens_from_messages() -> None:
+    llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")  # type: ignore[call-arg]
+
+    # Test simple case
+    messages = [
+        SystemMessage(content="You are a scientist"),
+        HumanMessage(content="Hello, Claude"),
+    ]
+    num_tokens = llm.get_num_tokens_from_messages(messages)
+    assert num_tokens > 0
+
+    # Test tool use
+    @tool(parse_docstring=True)
+    def get_weather(location: str) -> str:
+        """Get the current weather in a given location
+
+        Args:
+            location: The city and state, e.g. San Francisco, CA
+        """
+        return "Sunny"
+
+    messages = [
+        HumanMessage(content="What's the weather like in San Francisco?"),
+    ]
+    num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
+    assert num_tokens > 0
+
+    messages = [
+        HumanMessage(content="What's the weather like in San Francisco?"),
+        AIMessage(
+            content=[
+                {"text": "Let's see.", "type": "text"},
+                {
+                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
+                    "input": {"location": "SF"},
+                    "name": "get_weather",
+                    "type": "tool_use",
+                },
+            ],
+            tool_calls=[
+                {
+                    "name": "get_weather",
+                    "args": {"location": "SF"},
+                    "id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
+                    "type": "tool_call",
+                },
+            ],
+        ),
+        ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"),
+    ]
+    num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
+    assert num_tokens > 0
+
+
 class GetWeather(BaseModel):
     """Get the current weather in a given location"""

libs/partners/openai/tests/unit_tests/chat_models/test_base.py:

@@ -331,7 +331,7 @@ def dummy_tool() -> BaseTool:
         arg1: int = Field(..., description="foo")
         arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")

-    class DummyFunction(BaseTool):
+    class DummyFunction(BaseTool):  # type: ignore[override]
         args_schema: Type[BaseModel] = Schema
         name: str = "dummy_function"
         description: str = "dummy function"
libs/partners/openai/langchain_openai/chat_models/base.py:

@@ -886,8 +886,13 @@ class BaseChatOpenAI(BaseChatModel):
         _, encoding_model = self._get_encoding_model()
         return encoding_model.encode(text)

-    # TODO: Count bound tools as part of input.
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: List[BaseMessage],
+        tools: Optional[
+            Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

         **Requirements**: You must have the ``pillow`` installed if you want to count
@@ -897,7 +902,18 @@ class BaseChatOpenAI(BaseChatModel):
         counting.

         OpenAI reference: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+
+        Args:
+            messages: The message inputs to tokenize.
+            tools: If provided, sequence of dict, BaseModel, function, or BaseTools
+                to be converted to tool schemas.
+        """
+        # TODO: Count bound tools as part of input.
+        if tools is not None:
+            warnings.warn(
+                "Counting tokens in tool schemas is not yet supported. Ignoring tools."
+            )
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()