Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-20 22:03:52 +00:00)
core[patch]: Allow bound models as token_counter in trim_messages (#25563)
commit 4bd005adb6 (parent e01c6789c4)
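What this change enables, as a minimal sketch: a chat model wrapped with .bind(...) can now be passed directly as token_counter to trim_messages. The model class, the bound kwargs, and the token budget below are illustrative assumptions, not part of the commit; any chat model exposing get_num_tokens_from_messages should behave the same way.

from langchain_core.messages import AIMessage, HumanMessage, trim_messages
from langchain_openai import ChatOpenAI  # assumed installed; any chat model works

# .bind() returns a RunnableBinding, not a BaseLanguageModel subclass,
# which is exactly what the previous isinstance() check rejected.
bound_model = ChatOpenAI(model="gpt-4o-mini").bind(stop=["\n\n"])

trimmer = trim_messages(
    max_tokens=50,              # illustrative budget
    strategy="last",
    token_counter=bound_model,  # now accepted via duck typing
)
trimmed = trimmer.invoke(
    [
        HumanMessage("This is a short message."),
        AIMessage("This is a short reply."),
        HumanMessage("And a follow-up question?"),
    ]
)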
@@ -514,6 +514,8 @@ def merge_message_runs(
     return merged


+# TODO: Update so validation errors (for token_counter, for example) are raised on
+# init not at runtime.
 @_runnable_support
 def trim_messages(
     messages: Union[Iterable[MessageLikeRepresentation], PromptValue],
@@ -759,24 +761,30 @@ def trim_messages(
             AIMessage("This is a 4 token text. The full message is 10 tokens.", id="fourth"),
         ]
     """  # noqa: E501
-    from langchain_core.language_models import BaseLanguageModel
     if start_on and strategy == "first":
         raise ValueError
     if include_system and strategy == "first":
         raise ValueError
     messages = convert_to_messages(messages)
-    if isinstance(token_counter, BaseLanguageModel):
-        list_token_counter = token_counter.get_num_tokens_from_messages
-    elif (
-        list(inspect.signature(token_counter).parameters.values())[0].annotation
-        is BaseMessage
-    ):
+    if hasattr(token_counter, "get_num_tokens_from_messages"):
+        list_token_counter = getattr(token_counter, "get_num_tokens_from_messages")
+    elif callable(token_counter):
+        if (
+            list(inspect.signature(token_counter).parameters.values())[0].annotation
+            is BaseMessage
+        ):

-        def list_token_counter(messages: Sequence[BaseMessage]) -> int:
-            return sum(token_counter(msg) for msg in messages)  # type: ignore[arg-type, misc]
+            def list_token_counter(messages: Sequence[BaseMessage]) -> int:
+                return sum(token_counter(msg) for msg in messages)  # type: ignore[arg-type, misc]

-    else:
-        list_token_counter = token_counter  # type: ignore[assignment]
+        else:
+            list_token_counter = token_counter  # type: ignore[assignment]
+    else:
+        raise ValueError(
+            f"'token_counter' expected to be a model that implements "
+            f"'get_num_tokens_from_messages()' or a function. Received object of type "
+            f"{type(token_counter)}."
+        )

     try:
         from langchain_text_splitters import TextSplitter
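The net effect of this hunk: token_counter is now duck-typed. Anything exposing get_num_tokens_from_messages (a model, or the object returned by .bind()) is used directly; a plain callable is inspected to see whether it counts one BaseMessage or a whole list; anything else raises a ValueError when the trimmer runs. A standalone sketch of the two callable forms, with made-up per-message token counts (not library code):

from typing import List

from langchain_core.messages import BaseMessage, HumanMessage, trim_messages


def per_message_counter(message: BaseMessage) -> int:
    # first parameter annotated as BaseMessage, so trim_messages wraps this
    # in a sum() over the message list (the nested list_token_counter branch)
    return 5  # made-up flat cost per message


def per_list_counter(messages: List[BaseMessage]) -> int:
    # takes the whole list, so it is used as-is (the inner else branch)
    return 5 * len(messages)


msgs = [HumanMessage("one"), HumanMessage("two"), HumanMessage("three")]
result_a = trim_messages(msgs, max_tokens=10, token_counter=per_message_counter)
result_b = trim_messages(msgs, max_tokens=10, token_counter=per_list_counter)
# With these counts and budget, both forms should make the same trimming decision.
assert result_a == result_b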
@@ -2,6 +2,7 @@ from typing import Dict, List, Type

 import pytest

+from langchain_core.language_models.fake_chat_models import FakeChatModel
 from langchain_core.messages import (
     AIMessage,
     BaseMessage,
@@ -316,6 +317,19 @@ def test_trim_messages_invoke() -> None:
     assert actual == expected


+def test_trim_messages_bound_model_token_counter() -> None:
+    trimmer = trim_messages(
+        max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
+    )
+    trimmer.invoke([HumanMessage("foobar")])
+
+
+def test_trim_messages_bad_token_counter() -> None:
+    trimmer = trim_messages(max_tokens=10, token_counter={})
+    with pytest.raises(ValueError):
+        trimmer.invoke([HumanMessage("foobar")])
+
+
 def dummy_token_counter(messages: List[BaseMessage]) -> int:
     # treat each message like it adds 3 default tokens at the beginning
     # of the message and at the end of the message. 3 + 4 + 3 = 10 tokens
@@ -338,3 +352,8 @@ def dummy_token_counter(messages: List[BaseMessage]) -> int:
         + default_msg_suffix_len
     )
     return count
+
+
+class FakeTokenCountingModel(FakeChatModel):
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        return dummy_token_counter(messages)
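Why the duck-typed lookup in the main hunk succeeds for the bound model used in the new test: .bind() returns a RunnableBinding, which is not a BaseLanguageModel subclass but forwards unknown attribute access to the model it wraps, so get_num_tokens_from_messages is still reachable. A quick illustrative check, mirroring the test above (FakeChatModel and the foo kwarg are just stand-ins):

from langchain_core.language_models import BaseLanguageModel
from langchain_core.language_models.fake_chat_models import FakeChatModel

bound = FakeChatModel().bind(foo="bar")  # RunnableBinding wrapping the fake model

print(isinstance(bound, BaseLanguageModel))            # False: the old check rejected it
print(hasattr(bound, "get_num_tokens_from_messages"))  # True: the new check accepts it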