mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-13 22:59:05 +00:00
update signature in core
This commit is contained in:
parent
826040f8b8
commit
e4bfc84d6e
@ -364,13 +364,22 @@ class BaseLanguageModel(
|
|||||||
"""
|
"""
|
||||||
return len(self.get_token_ids(text))
|
return len(self.get_token_ids(text))
|
||||||
|
|
||||||
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
tools: Optional[Sequence] = None,
|
||||||
|
) -> int:
|
||||||
"""Get the number of tokens in the messages.
|
"""Get the number of tokens in the messages.
|
||||||
|
|
||||||
Useful for checking if an input fits in a model's context window.
|
Useful for checking if an input fits in a model's context window.
|
||||||
|
|
||||||
|
**Note**: the base implementation of get_num_tokens_from_messages ignores
|
||||||
|
tool schemas.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: The message inputs to tokenize.
|
messages: The message inputs to tokenize.
|
||||||
|
tools: If provided, sequence of dict, BaseModel, function, or BaseTools
|
||||||
|
to be converted to tool schemas.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The sum of the number of tokens across the messages.
|
The sum of the number of tokens across the messages.
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
|
import typing
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -19,6 +22,7 @@ from langchain_core.messages.utils import (
|
|||||||
merge_message_runs,
|
merge_message_runs,
|
||||||
trim_messages,
|
trim_messages,
|
||||||
)
|
)
|
||||||
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
|
@pytest.mark.parametrize("msg_cls", [HumanMessage, AIMessage, SystemMessage])
|
||||||
@ -431,7 +435,15 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int:
|
|||||||
|
|
||||||
|
|
||||||
class FakeTokenCountingModel(FakeChatModel):
|
class FakeTokenCountingModel(FakeChatModel):
|
||||||
def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
|
def get_num_tokens_from_messages(
|
||||||
|
self,
|
||||||
|
messages: list[BaseMessage],
|
||||||
|
tools: Optional[
|
||||||
|
Sequence[
|
||||||
|
Union[typing.Dict[str, Any], type, Callable, BaseTool] # noqa: UP006
|
||||||
|
]
|
||||||
|
] = None,
|
||||||
|
) -> int:
|
||||||
return dummy_token_counter(messages)
|
return dummy_token_counter(messages)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user