update signature in core

This commit is contained in:
Chester Curme 2024-11-12 10:29:34 -05:00
parent 826040f8b8
commit e4bfc84d6e
2 changed files with 23 additions and 2 deletions

View File

@ -364,13 +364,22 @@ class BaseLanguageModel(
"""
return len(self.get_token_ids(text))
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[Sequence] = None,
) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input fits in a model's context window.
**Note**: the base implementation of get_num_tokens_from_messages ignores
tool schemas.
Args:
messages: The message inputs to tokenize.
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
to be converted to tool schemas.
Returns:
The sum of the number of tokens across the messages.

View File

@ -1,5 +1,8 @@
import base64
import json
import typing
from collections.abc import Sequence
from typing import Any, Callable, Optional, Union
import pytest
@ -19,6 +22,7 @@ from langchain_core.messages.utils import (
merge_message_runs,
trim_messages,
)
from langchain_core.tools import BaseTool
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
class FakeTokenCountingModel(FakeChatModel):
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
def get_num_tokens_from_messages(
self,
messages: list[BaseMessage],
tools: Optional[
Sequence[
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
]
] = None,
) -> int:
return dummy_token_counter(messages)