mirror of
https://github.com/hwchase17/langchain.git
synced 2026-03-18 11:07:36 +00:00
chore(core): improve typing of messages utils functions (#34225)
With this we get the correct types for `_runnable_support` annotated functions. * return list[BaseMessage] when messages is not None * return Runnable when messages is None * typing of function args
This commit is contained in:
committed by
GitHub
parent
ba6c2590ae
commit
a64aee310c
@@ -15,12 +15,16 @@ import json
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Callable, Iterable, Sequence
|
||||
from functools import partial
|
||||
from functools import partial, wraps
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Concatenate,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
@@ -384,33 +388,54 @@ def convert_to_messages(
|
||||
return [_convert_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def _runnable_support(func: Callable) -> Callable:
|
||||
_P = ParamSpec("_P")
|
||||
_R_co = TypeVar("_R_co", covariant=True)
|
||||
|
||||
|
||||
class _RunnableSupportCallable(Protocol[_P, _R_co]):
|
||||
@overload
|
||||
def wrapped(
|
||||
messages: None = None, **kwargs: Any
|
||||
) -> Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]: ...
|
||||
def __call__(
|
||||
self,
|
||||
messages: None = None,
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
|
||||
|
||||
@overload
|
||||
def wrapped(
|
||||
messages: Sequence[MessageLikeRepresentation], **kwargs: Any
|
||||
) -> list[BaseMessage]: ...
|
||||
def __call__(
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation] | PromptValue,
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> _R_co: ...
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]: ...
|
||||
|
||||
|
||||
def _runnable_support(
|
||||
func: Callable[
|
||||
Concatenate[Sequence[MessageLikeRepresentation] | PromptValue, _P], _R_co
|
||||
],
|
||||
) -> _RunnableSupportCallable[_P, _R_co]:
|
||||
@wraps(func)
|
||||
def wrapped(
|
||||
messages: Sequence[MessageLikeRepresentation] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> (
|
||||
list[BaseMessage]
|
||||
| Runnable[Sequence[MessageLikeRepresentation], list[BaseMessage]]
|
||||
):
|
||||
messages: Sequence[MessageLikeRepresentation] | PromptValue | None = None,
|
||||
*args: _P.args,
|
||||
**kwargs: _P.kwargs,
|
||||
) -> _R_co | Runnable[Sequence[MessageLikeRepresentation], _R_co]:
|
||||
# Import locally to prevent circular import.
|
||||
from langchain_core.runnables.base import RunnableLambda # noqa: PLC0415
|
||||
|
||||
if messages is not None:
|
||||
return func(messages, **kwargs)
|
||||
return func(messages, *args, **kwargs)
|
||||
return RunnableLambda(partial(func, **kwargs), name=func.__name__)
|
||||
|
||||
wrapped.__doc__ = func.__doc__
|
||||
return wrapped
|
||||
return cast("_RunnableSupportCallable[_P, _R_co]", wrapped)
|
||||
|
||||
|
||||
@_runnable_support
|
||||
|
||||
@@ -2,10 +2,10 @@ import base64
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import pytest
|
||||
from typing_extensions import override
|
||||
from typing_extensions import NotRequired, override
|
||||
|
||||
from langchain_core.language_models.fake_chat_models import FakeChatModel
|
||||
from langchain_core.messages import (
|
||||
@@ -135,6 +135,16 @@ def test_merge_messages_tool_messages() -> None:
|
||||
assert messages == messages_model_copy
|
||||
|
||||
|
||||
class FilterFields(TypedDict):
|
||||
include_names: NotRequired[Sequence[str]]
|
||||
exclude_names: NotRequired[Sequence[str]]
|
||||
include_types: NotRequired[Sequence[str | type[BaseMessage]]]
|
||||
exclude_types: NotRequired[Sequence[str | type[BaseMessage]]]
|
||||
include_ids: NotRequired[Sequence[str]]
|
||||
exclude_ids: NotRequired[Sequence[str]]
|
||||
exclude_tool_calls: NotRequired[Sequence[str] | bool]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"filters",
|
||||
[
|
||||
@@ -153,7 +163,7 @@ def test_merge_messages_tool_messages() -> None:
|
||||
{"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
|
||||
],
|
||||
)
|
||||
def test_filter_message(filters: dict) -> None:
|
||||
def test_filter_message(filters: FilterFields) -> None:
|
||||
messages = [
|
||||
SystemMessage("foo", name="blah", id="1"),
|
||||
HumanMessage("bar", name="blur", id="2"),
|
||||
@@ -192,7 +202,7 @@ def test_filter_message_exclude_tool_calls() -> None:
|
||||
assert expected == actual
|
||||
|
||||
# test explicitly excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
|
||||
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
|
||||
assert expected == actual
|
||||
|
||||
# test excluding a specific tool call
|
||||
@@ -234,7 +244,7 @@ def test_filter_message_exclude_tool_calls_content_blocks() -> None:
|
||||
assert expected == actual
|
||||
|
||||
# test explicitly excluding all tool calls
|
||||
actual = filter_messages(messages, exclude_tool_calls={"1", "2"})
|
||||
actual = filter_messages(messages, exclude_tool_calls=["1", "2"])
|
||||
assert expected == actual
|
||||
|
||||
# test excluding a specific tool call
|
||||
@@ -508,13 +518,14 @@ def test_trim_messages_invoke() -> None:
|
||||
|
||||
def test_trim_messages_bound_model_token_counter() -> None:
|
||||
trimmer = trim_messages(
|
||||
max_tokens=10, token_counter=FakeTokenCountingModel().bind(foo="bar")
|
||||
max_tokens=10,
|
||||
token_counter=FakeTokenCountingModel().bind(foo="bar"), # type: ignore[call-overload]
|
||||
)
|
||||
trimmer.invoke([HumanMessage("foobar")])
|
||||
|
||||
|
||||
def test_trim_messages_bad_token_counter() -> None:
|
||||
trimmer = trim_messages(max_tokens=10, token_counter={})
|
||||
trimmer = trim_messages(max_tokens=10, token_counter={}) # type: ignore[call-overload]
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=re.escape(
|
||||
@@ -608,7 +619,9 @@ def test_trim_messages_mixed_content_with_partial() -> None:
|
||||
|
||||
assert len(result) == 1
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0]["text"] == "First part of text."
|
||||
content = result[0].content[0]
|
||||
assert isinstance(content, dict)
|
||||
assert content["text"] == "First part of text."
|
||||
assert messages == messages_copy
|
||||
|
||||
|
||||
|
||||
@@ -516,14 +516,17 @@ class SummarizationMiddleware(AgentMiddleware):
|
||||
try:
|
||||
if self.trim_tokens_to_summarize is None:
|
||||
return messages
|
||||
return trim_messages(
|
||||
messages,
|
||||
max_tokens=self.trim_tokens_to_summarize,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
allow_partial=True,
|
||||
include_system=True,
|
||||
return cast(
|
||||
"list[AnyMessage]",
|
||||
trim_messages(
|
||||
messages,
|
||||
max_tokens=self.trim_tokens_to_summarize,
|
||||
token_counter=self.token_counter,
|
||||
start_on="human",
|
||||
strategy="last",
|
||||
allow_partial=True,
|
||||
include_system=True,
|
||||
),
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
return messages[-_DEFAULT_FALLBACK_MESSAGE_COUNT:]
|
||||
|
||||
4
libs/langchain_v1/uv.lock
generated
4
libs/langchain_v1/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.10.0, <4.0.0"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14' and platform_python_implementation == 'PyPy'",
|
||||
@@ -2166,7 +2166,7 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "langchain-core"
|
||||
version = "1.1.0"
|
||||
version = "1.1.1"
|
||||
source = { editable = "../core" }
|
||||
dependencies = [
|
||||
{ name = "jsonpatch" },
|
||||
|
||||
Reference in New Issue
Block a user