From e4bfc84d6e54b3961a0553266c9cdf5335dbe0d3 Mon Sep 17 00:00:00 2001 From: Chester Curme Date: Tue, 12 Nov 2024 10:29:34 -0500 Subject: [PATCH] update signature in core --- libs/core/langchain_core/language_models/base.py | 11 ++++++++++- libs/core/tests/unit_tests/messages/test_utils.py | 14 +++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 5e60c890a47..43374c9e8f8 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -364,13 +364,22 @@ class BaseLanguageModel( """ return len(self.get_token_ids(text)) - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[Sequence] = None, + ) -> int: """Get the number of tokens in the messages. Useful for checking if an input fits in a model's context window. + **Note**: the base implementation of get_num_tokens_from_messages ignores + tool schemas. + Args: messages: The message inputs to tokenize. + tools: If provided, a sequence of dicts, BaseModel classes, functions, or + BaseTools to be converted to tool schemas. Returns: The sum of the number of tokens across the messages. 
diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index ff251e36320..5941c14831a 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1,5 +1,8 @@ import base64 import json +import typing +from collections.abc import Sequence +from typing import Any, Callable, Optional, Union import pytest @@ -19,6 +22,7 @@ from langchain_core.messages.utils import ( merge_message_runs, trim_messages, ) +from langchain_core.tools import BaseTool @pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage]) @@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int: class FakeTokenCountingModel(FakeChatModel): - def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int: + def get_num_tokens_from_messages( + self, + messages: list[BaseMessage], + tools: Optional[ + Sequence[ + Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006 + ] + ] = None, + ) -> int: return dummy_token_counter(messages)