Compare commits

...

25 Commits

Author SHA1 Message Date
ccurme 36f5bfb552 Merge branch 'cc/anthropic_313' into cc/tool_attr 2024-11-12 11:04:42 -05:00
Chester Curme 29ba5fd1f1 community 2024-11-12 10:29:56 -05:00
Chester Curme 62a5f993f2 openai 2024-11-12 10:29:51 -05:00
Chester Curme db7b518308 update anthropic 2024-11-12 10:29:43 -05:00
Chester Curme e4bfc84d6e update signature in core 2024-11-12 10:29:34 -05:00
Chester Curme a4857bf09b update openai 2024-11-12 10:16:54 -05:00
Chester Curme 7de02c5997 update anthropic 2024-11-12 10:16:47 -05:00
Chester Curme 76971094ab update community 2024-11-12 10:16:34 -05:00
Chester Curme 42042abd82 update core 2024-11-12 09:56:35 -05:00
Chester Curme 668e4c68ec pass tools into get_num_tokens 2024-11-09 14:48:40 -05:00
Chester Curme 077199c5de Merge branch 'cc/anthropic_313' into cc/tool_attr 2024-11-09 14:33:57 -05:00
Chester Curme 826040f8b8 Merge branch 'master' into cc/anthropic_313 2024-11-09 14:33:46 -05:00
Chester Curme 1bf1ab7986 fix TypedDict import 2024-11-06 10:13:48 -05:00
Chester Curme c8d96ad346 update test 2024-11-06 10:02:00 -05:00
Chester Curme 8662fd8c7d add formatted_tools field 2024-11-06 09:54:57 -05:00
Chester Curme a72e9d14f0 bump to 0.3 2024-11-05 12:57:24 -05:00
Chester Curme 7f5c21dc0c bump anthropic dep 2024-11-05 12:57:13 -05:00
Chester Curme ffb26b3298 lint 2024-11-05 11:56:37 -05:00
Chester Curme 84b4ea2198 update extended test 2024-11-05 11:51:39 -05:00
Chester Curme c54b676a04 update docstring 2024-11-05 11:36:57 -05:00
Chester Curme a3350f4174 test with 3.13 in CI 2024-11-05 11:24:30 -05:00
Chester Curme 5fc4ec6bc9 lint 2024-11-05 11:17:44 -05:00
Chester Curme caed4e4ce8 implement ChatAnthropic.get_num_tokens_from_messages 2024-11-05 11:09:07 -05:00
Chester Curme ff2ef48b35 drop support for Anthropic.get_num_tokens 2024-11-05 11:08:05 -05:00
Chester Curme b2e8df4cea lock 2024-11-05 11:07:35 -05:00
14 changed files with 521 additions and 542 deletions

View File

@@ -37,7 +37,6 @@ IGNORED_PARTNERS = [
PY_312_MAX_PACKAGES = [
f"libs/partners/{integration}"
for integration in [
"anthropic",
"chroma",
"couchbase",
"huggingface",

View File

@@ -5,10 +5,22 @@ from __future__ import annotations
import logging
import os
import sys
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Set,
Type,
Union,
)
import requests
from langchain_core.messages import BaseMessage
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr, model_validator
@@ -197,10 +209,18 @@ class ChatAnyscale(ChatOpenAI):
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
"""
if tools is not None:
warnings.warn("Counting tokens in tool schemas is not yet supported.")
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()

View File

@@ -4,9 +4,21 @@ from __future__ import annotations
import logging
import sys
from typing import TYPE_CHECKING, Any, Dict, Optional, Set
import warnings
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
Set,
Type,
Union,
)
from langchain_core.messages import BaseMessage
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, model_validator
@@ -138,11 +150,19 @@ class ChatEverlyAI(ChatOpenAI):
encoding = tiktoken_.get_encoding(model)
return model, encoding
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if tools is not None:
warnings.warn("Counting tokens in tool schemas is not yet supported.")
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()

View File

@@ -46,6 +46,7 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
get_from_dict_or_env,
get_pydantic_field_names,
@@ -644,11 +645,19 @@ class ChatOpenAI(BaseChatModel):
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: List[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if tools is not None:
warnings.warn("Counting tokens in tool schemas is not yet supported.")
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()

View File

@@ -364,13 +364,22 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.
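
Note: the base implementation accepts the new tools argument for interface compatibility but, per the note above, ignores it; the count depends only on the messages. A minimal sketch of that contract using core's FakeListChatModel as a stand-in (and assuming transformers is installed, since the default counter falls back to the GPT-2 tokenizer):

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import HumanMessage

model = FakeListChatModel(responses=["ok"])
messages = [HumanMessage(content="Hello, world")]

# The base class tokenizes only the message text; the tool schema changes nothing.
with_tools = model.get_num_tokens_from_messages(messages, tools=[{"name": "noop"}])
without_tools = model.get_num_tokens_from_messages(messages)
assert with_tools == without_tools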

View File

@@ -1,5 +1,8 @@
import base64
import json
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pytest
@@ -19,6 +22,7 @@ from langchain_core.messages.utils import (
merge_message_runs,
trim_messages,
)
from langchain_core.tools import BaseTool
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
class FakeTokenCountingModel(FakeChatModel):
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
]
] = None,
) -> int:
return dummy_token_counter(messages)
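
As this test shows, overriding subclasses must widen their signatures the same way, or type checkers will flag the override and callers passing tools will break. A sketch of a minimal compatible override (the character-counting model is illustrative, not from the PR):

from collections.abc import Sequence
from typing import Any, Callable, Optional, Union

from langchain_core.language_models import FakeListChatModel
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import BaseTool

class CharCountingModel(FakeListChatModel):
    """Toy model whose "token" count is the total character count."""

    def get_num_tokens_from_messages(
        self,
        messages: list[BaseMessage],
        tools: Optional[
            Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
        ] = None,
    ) -> int:
        # Match the base contract: accept `tools`, count only the messages.
        return sum(len(str(m.content)) for m in messages)

model = CharCountingModel(responses=["ok"])
assert model.get_num_tokens_from_messages([HumanMessage(content="abcd")]) == 4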

View File

@@ -3,7 +3,6 @@ from unittest import mock
import pytest
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig, RunnableSequence
from pydantic import SecretStr
@@ -180,9 +179,6 @@ def test_configurable_with_default() -> None:
)
assert model_with_config.model == "claude-3-sonnet-20240229" # type: ignore[attr-defined]
# Anthropic defaults to using `transformers` for token counting.
with pytest.raises(ImportError):
model_with_config.get_num_tokens_from_messages([(HumanMessage("foo"))]) # type: ignore[attr-defined]
assert model_with_config.model_dump() == { # type: ignore[attr-defined]
"name": None,

View File

@@ -21,7 +21,7 @@ from typing import (
)
import anthropic
from langchain_core._api import deprecated
from langchain_core._api import beta, deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
@@ -1113,6 +1113,41 @@ class ChatAnthropic(BaseChatModel):
else:
return llm | output_parser
@beta()
def get_num_tokens_from_messages(
self,
messages: List[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Count tokens in a sequence of input messages.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
.. versionchanged:: 0.3.0
Uses Anthropic's token counting API to count tokens in messages. See:
https://docs.anthropic.com/en/docs/build-with-claude/token-counting
"""
formatted_system, formatted_messages = _format_messages(messages)
kwargs: Dict[str, Any] = {}
if isinstance(formatted_system, str):
kwargs["system"] = formatted_system
if tools:
kwargs["tools"] = [convert_to_anthropic_tool(tool) for tool in tools]
response = self._client.beta.messages.count_tokens(
betas=["token-counting-2024-11-01"],
model=self.model,
messages=formatted_messages, # type: ignore[arg-type]
**kwargs,
)
return response.input_tokens
class AnthropicTool(TypedDict):
"""Anthropic tool definition."""

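Unlike the OpenAI-family models, ChatAnthropic can now count tool tokens for real by delegating to Anthropic's beta token-counting endpoint (the @beta() decorator means callers see a LangChainBetaWarning on first use). A sketch mirroring the integration tests below; it requires a valid ANTHROPIC_API_KEY and makes one API call per count, but generates no completion. An Anthropic-format tool dict (name/description/input_schema) should pass through convert_to_anthropic_tool unchanged:

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage

llm = ChatAnthropic(model="claude-3-5-sonnet-20241022")

messages = [
    SystemMessage(content="You are a scientist"),
    HumanMessage(content="Hello, Claude"),
]
num_tokens = llm.get_num_tokens_from_messages(messages)

# Tool schemas are counted too, so the total grows once a tool is attached.
num_with_tool = llm.get_num_tokens_from_messages(
    messages,
    tools=[
        {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "input_schema": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ],
)
assert num_with_tool > num_tokens
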
View File

@@ -109,7 +109,6 @@ class _AnthropicCommon(BaseLanguageModel):
)
self.HUMAN_PROMPT = anthropic.HUMAN_PROMPT
self.AI_PROMPT = anthropic.AI_PROMPT
self.count_tokens = self.client.count_tokens
return self
@property
@@ -375,9 +374,11 @@ class AnthropicLLM(LLM, _AnthropicCommon):
def get_num_tokens(self, text: str) -> int:
"""Calculate number of tokens."""
if not self.count_tokens:
raise NameError("Please ensure the anthropic package is loaded")
return self.count_tokens(text)
raise NotImplementedError(
"Anthropic's legacy count_tokens method was removed in anthropic 0.39.0 "
"and langchain-anthropic 0.3.0. Please use "
"ChatAnthropic.get_num_tokens_from_messages instead."
)
@deprecated(since="0.1.0", removal="0.3.0", alternative="AnthropicLLM")
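
Callers of the removed text counter migrate by wrapping their text in a message; a hedged before/after sketch (model name as in the tests below, requires an API key):

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage

# Before (anthropic < 0.39.0): AnthropicLLM(...).get_num_tokens(text) counted locally.
# After: count through the chat model's API-backed method.
chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
num_tokens = chat.get_num_tokens_from_messages(
    [HumanMessage(content="What is the weather in SF?")]
)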

File diff suppressed because it is too large

View File

@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "langchain-anthropic"
version = "0.2.4"
version = "0.3.0"
description = "An integration package connecting AnthropicMessages and LangChain"
authors = []
readme = "README.md"
@@ -20,7 +20,7 @@ disallow_untyped_defs = "True"
[tool.poetry.dependencies]
python = ">=3.9,<4.0"
anthropic = ">=0.30.0,<1"
anthropic = ">=0.39.0,<1"
langchain-core = "^0.3.15"
pydantic = "^2.7.4"

View File

@@ -317,7 +317,7 @@ async def test_anthropic_async_streaming_callback() -> None:
def test_anthropic_multimodal() -> None:
"""Test that multimodal inputs are handled correctly."""
chat = ChatAnthropic(model=MODEL_NAME) # type: ignore[call-arg]
messages = [
messages: list[BaseMessage] = [
HumanMessage(
content=[
{
@@ -334,6 +334,9 @@ def test_anthropic_multimodal() -> None:
response = chat.invoke(messages)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
num_tokens = chat.get_num_tokens_from_messages(messages)
assert num_tokens > 0
def test_streaming() -> None:
@@ -505,6 +508,60 @@ def test_with_structured_output() -> None:
assert response["location"]
def test_get_num_tokens_from_messages() -> None:
llm = ChatAnthropic(model="claude-3-5-sonnet-20241022") # type: ignore[call-arg]
# Test simple case
messages = [
SystemMessage(content="You are a scientist"),
HumanMessage(content="Hello, Claude"),
]
num_tokens = llm.get_num_tokens_from_messages(messages)
assert num_tokens > 0
# Test tool use
@tool(parse_docstring=True)
def get_weather(location: str) -> str:
"""Get the current weather in a given location
Args:
location: The city and state, e.g. San Francisco, CA
"""
return "Sunny"
messages = [
HumanMessage(content="What's the weather like in San Francisco?"),
]
num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
assert num_tokens > 0
messages = [
HumanMessage(content="What's the weather like in San Francisco?"),
AIMessage(
content=[
{"text": "Let's see.", "type": "text"},
{
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
"input": {"location": "SF"},
"name": "get_weather",
"type": "tool_use",
},
],
tool_calls=[
{
"name": "get_weather",
"args": {"location": "SF"},
"id": "toolu_01V6d6W32QGGSmQm4BT98EKk",
"type": "tool_call",
},
],
),
ToolMessage(content="Sunny", tool_call_id="toolu_01V6d6W32QGGSmQm4BT98EKk"),
]
num_tokens = llm.get_num_tokens_from_messages(messages, tools=[get_weather])
assert num_tokens > 0
class GetWeather(BaseModel):
"""Get the current weather in a given location"""

View File

@@ -331,7 +331,7 @@ def dummy_tool() -> BaseTool:
arg1: int = Field(..., description="foo")
arg2: Literal["bar", "baz"] = Field(..., description="one of 'bar', 'baz'")
class DummyFunction(BaseTool):
class DummyFunction(BaseTool): # type: ignore[override]
args_schema: Type[BaseModel] = Schema
name: str = "dummy_function"
description: str = "dummy function"

View File

@@ -886,8 +886,13 @@ class BaseChatOpenAI(BaseChatModel):
_, encoding_model = self._get_encoding_model()
return encoding_model.encode(text)
# TODO: Count bound tools as part of input.
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: List[BaseMessage],
tools: Optional[
Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]]
] = None,
) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
**Requirements**: You must have ``pillow`` installed if you want to count
@@ -897,7 +902,16 @@ class BaseChatOpenAI(BaseChatModel):
counting.
OpenAI reference: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
"""
# TODO: Count bound tools as part of input.
if tools is not None:
warnings.warn("Counting tokens in tool schemas is not yet supported.")
if sys.version_info[1] <= 7:
return super().get_num_tokens_from_messages(messages)
model, encoding = self._get_encoding_model()
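
For langchain-openai the counting itself stays local to tiktoken; passing tools only raises the warning added in this hunk. A sketch (placeholder API key, since counting makes no request):

import warnings

from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4", api_key="sk-placeholder")

messages = [HumanMessage(content="What's the weather like in San Francisco?")]
baseline = llm.get_num_tokens_from_messages(messages)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with_tools = llm.get_num_tokens_from_messages(messages, tools=[{"name": "get_weather"}])

# The tool schema is ignored for now, so the counts match.
assert with_tools == baseline
assert any("not yet supported" in str(w.message) for w in caught)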