chore(core): improve typing of messages utils functions (#34225)

With this we get the correct types for `_runnable_support` annotated
functions.
* return list[BaseMessage] when messages is not None
* return Runnable when messages is None
* typing of function args
This commit is contained in:
Christophe Bornet
2025-12-08 15:59:43 +01:00
committed by GitHub
parent ba6c2590ae
commit a64aee310c
4 changed files with 76 additions and 35 deletions

View File

@@ -15,12 +15,16 @@ import json
import logging
import math
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Concatenate,
Literal,
ParamSpec,
Protocol,
TypeVar,
cast,
overload,
)
@@ -384,33 +388,54 @@ def convert_to_messages(
return [_convert_to_message(m) for m in messages]
def _runnable_support(func: Callable) -> Callable:
_P = ParamSpec("_P")
_R_co = TypeVar("_R_co", covariant=True)
class _RunnableSupportCallable(Protocol[_P, _R_co]):
@overload
def wrapped(
messages: None = None, **kwargs: Any
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
def __call__(
self,
messages: None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
@overload
def wrapped(
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
) -> list[BaseMessage]: ...
def __call__(
self,
messages: Sequence[MessageLikeRepresentation] | PromptValue,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co: ...
def __call__(
self,
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
def _runnable_support(
func: Callable[
Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
],
) -> _RunnableSupportCallable[_P, _R_co]:
@wraps(func)
def wrapped(
messages: Sequence[MessageLikeRepresentation] | None = None,
**kwargs: Any,
) -> (
list[BaseMessage]
| Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
):
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
*args: _P.args,
**kwargs: _P.kwargs,
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
# Import locally to prevent circular import.
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
if messages is not None:
return func(messages, **kwargs)
return func(messages, *args, **kwargs)
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
wrapped.__doc__ = func.__doc__
return wrapped
return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)
@_runnable_support

View File

@@ -2,10 +2,10 @@ import base64
import json
import re
from collections.abc import Callable, Sequence
from typing import Any
from typing import Any, TypedDict
import pytest
from typing_extensions import override
from typing_extensions import NotRequired, override
from langchain_core.language_models.fake_chat_models import FakeChatModel
from langchain_core.messages import (
@@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
assert messages == messages_model_copy
class FilterFields(TypedDict):
include_names: NotRequired[Sequence[str]]
exclude_names: NotRequired[Sequence[str]]
include_types: NotRequired[Sequence[str | type[BaseMessage]]]
exclude_types: NotRequired[Sequence[str | type[BaseMessage]]]
include_ids: NotRequired[Sequence[str]]
exclude_ids: NotRequired[Sequence[str]]
exclude_tool_calls: NotRequired[Sequence[str] | bool]
@pytest.mark.parametrize(
"filters",
[
@@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
],
)
def test_filter_message(filters: dict) -> None:
def test_filter_message(filters: FilterFields) -> None:
messages = [
SystemMessage("foo", name="blah", id="1"),
HumanMessage("bar", name="blur", id="2"),
@@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
assert expected == actual
# test explicitly excluding all tool calls
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
assert expected == actual
# test excluding a specific tool call
@@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
assert expected == actual
# test explicitly excluding all tool calls
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
assert expected == actual
# test excluding a specific tool call
@@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:
def test_trim_messages_bound_model_token_counter() -> None:
trimmer = trim_messages(
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
max_tokens=10,
token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload]
)
trimmer.invoke([HumanMessage("foobar")])
def test_trim_messages_bad_token_counter() -> None:
trimmer = trim_messages(max_tokens=10, token_counter={})
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
with pytest.raises(
ValueError,
match=re.escape(
@@ -608,7 +619,9 @@ def test_trim_messages_mixed_content_with_partial() -> None:
assert len(result) == 1
assert len(result[0].content) == 1
assert result[0].content[0]["text"] == "First part of text."
content = result[0].content[0]
assert isinstance(content, dict)
assert content["text"] == "First part of text."
assert messages == messages_copy

View File

@@ -516,14 +516,17 @@ class SummarizationMiddleware(AgentMiddleware):
try:
if self.trim_tokens_to_summarize is None:
return messages
return trim_messages(
messages,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
return cast(
"list[AnyMessage]",
trim_messages(
messages,
max_tokens=self.trim_tokens_to_summarize,
token_counter=self.token_counter,
start_on="human",
strategy="last",
allow_partial=True,
include_system=True,
),
)
except Exception: # noqa: BLE001
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]

View File

@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.10.0, <4.0.0"
resolution-markers = [
"python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
@@ -2166,7 +2166,7 @@ wheels = [
[[package]]
name = "langchain-core"
version = "1.1.0"
version = "1.1.1"
source = { editable = "../core" }
dependencies = [
{ name = "jsonpatch" },