core: Fix some missing generic types

cbornet 2025-07-03 14:27:22 +02:00
parent 2fdccd789c
commit 8372d41f70
30 changed files with 238 additions and 233 deletions
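
The theme throughout is replacing bare generics (dict, list, Future, Runnable) with parameterized ones. A minimal before/after sketch of the pattern, taken from the example-selector change below:

    # Before: a bare list[dict] leaves the element types as Any.
    def select_examples(self, input_variables: dict[str, str]) -> list[dict]: ...

    # After: stating the element type lets mypy check what callers do with it.
    def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]: ...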

View File

@@ -125,7 +125,7 @@ class LangSmithLoader(BaseLoader):
            yield Document(content_str, metadata=metadata)

-def _stringify(x: Union[str, dict]) -> str:
+def _stringify(x: Union[str, dict[str, Any]]) -> str:
    if isinstance(x, str):
        return x
    try:

View File

@@ -202,13 +202,17 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage:
 MessageLikeRepresentation = Union[
-    BaseMessage, list[str], tuple[str, str], str, dict[str, Any]
+    BaseMessage,
+    list[str],
+    tuple[str, Union[str, list[Union[str, dict[str, Any]]]]],
+    str,
+    dict[str, Any],
 ]

 def _create_message_from_message_type(
     message_type: str,
-    content: str,
+    content: Union[str, list[Union[str, dict[str, Any]]]],
     name: Optional[str] = None,
     tool_call_id: Optional[str] = None,
     tool_calls: Optional[list[dict[str, Any]]] = None,
@@ -218,13 +222,13 @@ def _create_message_from_message_type(
     """Create a message from a message type and content string.

     Args:
-        message_type: (str) the type of the message (e.g., "human", "ai", etc.).
-        content: (str) the content string.
-        name: (str) the name of the message. Default is None.
-        tool_call_id: (str) the tool call id. Default is None.
-        tool_calls: (list[dict[str, Any]]) the tool calls. Default is None.
-        id: (str) the id of the message. Default is None.
-        additional_kwargs: (dict[str, Any]) additional keyword arguments.
+        message_type: the type of the message (e.g., "human", "ai", etc.).
+        content: the content string.
+        name: the name of the message. Default is None.
+        tool_call_id: the tool call id. Default is None.
+        tool_calls: the tool calls. Default is None.
+        id: the id of the message. Default is None.
+        **additional_kwargs: additional keyword arguments.

     Returns:
         a message of the appropriate type.
@@ -1004,12 +1008,13 @@ def convert_to_openai_messages(
     oai_messages: list = []

-    if is_single := isinstance(messages, (BaseMessage, dict, str)):
-        messages = [messages]
-    messages = convert_to_messages(messages)
-    for i, message in enumerate(messages):
+    messages_ = (
+        [messages]
+        if (is_single := isinstance(messages, (BaseMessage, dict, str, tuple)))
+        else messages
+    )
+    for i, message in enumerate(convert_to_messages(messages_)):
         oai_msg: dict = {"role": _get_message_openai_role(message)}
         tool_messages: list = []
         content: Union[str, list[dict]]
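
Illustration (not part of the commit): the widened tuple form in MessageLikeRepresentation matches what convert_to_messages already accepted at runtime, namely message content given as a list of content blocks:

    from langchain_core.messages.utils import convert_to_messages

    msgs = convert_to_messages(
        [
            ("human", "plain string content"),
            # The tuple's second element may now also be typed as a
            # list of str or dict content blocks.
            ("human", [{"type": "text", "text": "a content-block list"}]),
        ]
    )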

View File

@@ -225,7 +225,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser):
     @model_validator(mode="before")
     @classmethod
-    def validate_schema(cls, values: dict) -> Any:
+    def validate_schema(cls, values: dict[str, Any]) -> Any:
         """Validate the pydantic schema.

         Args:

View File

@@ -3,9 +3,9 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from pathlib import Path
 from typing import (
-    TYPE_CHECKING,
     Annotated,
     Any,
     Optional,
@@ -51,9 +51,6 @@ from langchain_core.prompts.string import (
 from langchain_core.utils import get_colored_text
 from langchain_core.utils.interactive_env import is_interactive_env

-if TYPE_CHECKING:
-    from collections.abc import Sequence

 class MessagesPlaceholder(BaseMessagePromptTemplate):
     """Prompt template that assumes variable is already list of messages.
@@ -772,7 +769,7 @@ MessageLikeRepresentation = Union[
     MessageLike,
     tuple[
         Union[str, type],
-        Union[str, list[dict], list[object]],
+        Union[str, Sequence[dict], Sequence[object]],
     ],
     str,
     dict[str, Any],
@@ -1435,7 +1432,9 @@ def _convert_to_message_template(
                 f" Got: {message}"
             )
             raise ValueError(msg)
-        message = (message["role"], message["content"])
+        message_type_str = message["role"]
+        template = message["content"]
+    else:
         if len(message) != 2:
             msg = f"Expected 2-tuple of (role, template), got {message}"
             raise ValueError(msg)

View File

@@ -100,7 +100,7 @@ class EvaluatorCallbackHandler(BaseTracer):
             )
         else:
             self.executor = None
-        self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
+        self.futures: weakref.WeakSet[Future[None]] = weakref.WeakSet()
         self.skip_unfinished = skip_unfinished
         self.project_name = project_name
         self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {}
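
Illustration (not part of the commit): parameterizing both the WeakSet and the Future pins down what the set may hold; submitting a None-returning callable produces a Future[None], so the add() below type-checks:

    import weakref
    from concurrent.futures import Future, ThreadPoolExecutor

    futures: weakref.WeakSet[Future[None]] = weakref.WeakSet()
    executor = ThreadPoolExecutor(max_workers=1)
    # print(...) returns None, so submit() yields a Future[None];
    # adding a Future[int] here would be flagged by mypy.
    futures.add(executor.submit(print, "evaluation done"))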

View File

@@ -10,7 +10,7 @@ class DummyExampleSelector(BaseExampleSelector):
     def add_example(self, example: dict[str, str]) -> None:
         self.example = example

-    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
         return [input_variables]

View File

@@ -276,7 +276,9 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
         self.on_retriever_error_common()

     # Overriding since BaseModel has __deepcopy__ method as well
-    def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":  # type: ignore[override]
+    def __deepcopy__(
+        self, memo: Union[dict[int, Any], None] = None
+    ) -> "FakeCallbackHandler":
         return self
@@ -426,5 +428,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
         self.on_text_common()

     # Overriding since BaseModel has __deepcopy__ method as well
-    def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":  # type: ignore[override]
+    def __deepcopy__(
+        self, memo: Union[dict[int, Any], None] = None
+    ) -> "FakeAsyncCallbackHandler":
         return self
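
Illustration (not part of the commit): copy.deepcopy calls __deepcopy__ with a memo dict keyed by id(obj), so Union[dict[int, Any], None] matches the real protocol and the # type: ignore[override] can be dropped:

    import copy
    from typing import Any, Union

    class ReusedOnCopy:
        def __deepcopy__(
            self, memo: Union[dict[int, Any], None] = None
        ) -> "ReusedOnCopy":
            # Returning self opts this object out of being copied.
            return self

    obj = ReusedOnCopy()
    assert copy.deepcopy(obj) is obj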

View File

@@ -41,7 +41,7 @@ if TYPE_CHECKING:

 @pytest.fixture
-def messages() -> list:
+def messages() -> list[BaseMessage]:
     return [
         SystemMessage(content="You are a test user."),
         HumanMessage(content="Hello, I am a test user."),
@@ -49,14 +49,14 @@ def messages() -> list:

 @pytest.fixture
-def messages_2() -> list:
+def messages_2() -> list[BaseMessage]:
     return [
         SystemMessage(content="You are a test user."),
         HumanMessage(content="Hello, I not a test user."),
     ]

-def test_batch_size(messages: list, messages_2: list) -> None:
+def test_batch_size(messages: list[BaseMessage], messages_2: list[BaseMessage]) -> None:
     # The base endpoint doesn't support native batching,
     # so we expect batch_size to always be 1
     llm = FakeListChatModel(responses=[str(i) for i in range(100)])
@@ -80,7 +80,9 @@ def test_batch_size(messages: list, messages_2: list) -> None:
     assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1

-async def test_async_batch_size(messages: list, messages_2: list) -> None:
+async def test_async_batch_size(
+    messages: list[BaseMessage], messages_2: list[BaseMessage]
+) -> None:
     llm = FakeListChatModel(responses=[str(i) for i in range(100)])
     # The base endpoint doesn't support native batching,
     # so we expect batch_size to always be 1
@@ -262,7 +264,7 @@ async def test_astream_implementation_uses_astream() -> None:
 class FakeTracer(BaseTracer):
     def __init__(self) -> None:
         super().__init__()
-        self.traced_run_ids: list = []
+        self.traced_run_ids: list[uuid.UUID] = []

     def _persist_run(self, run: Run) -> None:
         """Persist a run."""
@@ -415,7 +417,7 @@ async def test_disable_streaming_no_streaming_model_async(
 class FakeChatModelStartTracer(FakeTracer):
     def __init__(self) -> None:
         super().__init__()
-        self.messages: list = []
+        self.messages: list[list[list[BaseMessage]]] = []

     def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run:
         _, messages = args

View File

@@ -17,6 +17,7 @@ from langchain_core.messages import (
     ToolMessage,
 )
 from langchain_core.messages.utils import (
+    MessageLikeRepresentation,
     convert_to_messages,
     convert_to_openai_messages,
     count_tokens_approximately,
@@ -153,7 +154,7 @@ def test_merge_messages_tool_messages() -> None:
         {"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]},
     ],
 )
-def test_filter_message(filters: dict) -> None:
+def test_filter_message(filters: dict[str, Any]) -> None:
     messages = [
         SystemMessage("foo", name="blah", id="1"),
         HumanMessage("bar", name="blur", id="2"),
@@ -673,7 +674,7 @@ class FakeTokenCountingModel(FakeChatModel):
 def test_convert_to_messages() -> None:
-    message_like: list = [
+    message_like: list[MessageLikeRepresentation] = [
         # BaseMessage
         SystemMessage("1"),
         SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),
@@ -1179,7 +1180,7 @@ def test_convert_to_openai_messages_mixed_content_types() -> None:
 def test_convert_to_openai_messages_developer() -> None:
-    messages: list = [
+    messages: list[MessageLikeRepresentation] = [
         SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}),
         {"role": "developer", "content": "a"},
     ]

View File

@@ -17,7 +17,7 @@ from langchain_core.output_parsers.openai_tools import (
 )
 from langchain_core.outputs import ChatGeneration

-STREAMED_MESSAGES: list = [
+STREAMED_MESSAGES = [
     AIMessageChunk(content=""),
     AIMessageChunk(
         content="",
@@ -331,7 +331,7 @@ for message in STREAMED_MESSAGES:
     STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message)

-EXPECTED_STREAMED_JSON = [
+EXPECTED_STREAMED_JSON: list[dict[str, Any]] = [
     {},
     {"names": ["suz"]},
     {"names": ["suzy"]},
@@ -392,7 +392,7 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None:
     chain = input_iter | JsonOutputToolsParser()

     actual = list(chain.stream(None))
-    expected: list = [[]] + [
+    expected: list[list[dict[str, Any]]] = [[]] + [
         [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
     ]
     assert actual == expected
@@ -404,7 +404,7 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None
     chain = input_iter | JsonOutputToolsParser()

     actual = [p async for p in chain.astream(None)]
-    expected: list = [[]] + [
+    expected: list[list[dict[str, Any]]] = [[]] + [
         [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON
     ]
     assert actual == expected
@@ -416,7 +416,7 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None:
     chain = input_iter | JsonOutputToolsParser(return_id=True)

     actual = list(chain.stream(None))
-    expected: list = [[]] + [
+    expected: list[list[dict[str, Any]]] = [[]] + [
         [
             {
                 "type": "NameCollector",
@@ -435,7 +435,9 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None:
     chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

     actual = list(chain.stream(None))
-    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
+    expected: list[list[dict[str, Any]]] = [[]] + [
+        [chunk] for chunk in EXPECTED_STREAMED_JSON
+    ]
     assert actual == expected
@@ -446,7 +448,9 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) ->
     chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector")

     actual = [p async for p in chain.astream(None)]
-    expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON]
+    expected: list[list[dict[str, Any]]] = [[]] + [
+        [chunk] for chunk in EXPECTED_STREAMED_JSON
+    ]
     assert actual == expected
assert actual == expected assert actual == expected

View File

@@ -141,9 +141,7 @@ DEF_EXPECTED_RESULT = TestModel(

 def test_pydantic_output_parser() -> None:
     """Test PydanticOutputParser."""
-    pydantic_parser: PydanticOutputParser = PydanticOutputParser(
-        pydantic_object=TestModel
-    )
+    pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)

     result = pydantic_parser.parse(DEF_RESULT)
     assert result == DEF_EXPECTED_RESULT
@@ -152,9 +150,7 @@ def test_pydantic_output_parser() -> None:

 def test_pydantic_output_parser_fail() -> None:
     """Test PydanticOutputParser where completion result fails schema validation."""
-    pydantic_parser: PydanticOutputParser = PydanticOutputParser(
-        pydantic_object=TestModel
-    )
+    pydantic_parser = PydanticOutputParser(pydantic_object=TestModel)

     with pytest.raises(
         OutputParserException, match="Failed to parse TestModel from completion"

View File

@@ -1,4 +1,4 @@
-from typing import Union
+from typing import Any, Union

 import pytest
@@ -19,14 +19,14 @@ from langchain_core.outputs import ChatGeneration
         ],
     ],
 )
-def test_msg_with_text(content: Union[str, list]) -> None:
+def test_msg_with_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
     expected = "foo"
     actual = ChatGeneration(message=AIMessage(content=content)).text
     assert actual == expected

 @pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]])
-def test_msg_no_text(content: Union[str, list]) -> None:
+def test_msg_no_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None:
     expected = ""
     actual = ChatGeneration(message=AIMessage(content=content)).text
     assert actual == expected

View File

@@ -1,7 +1,7 @@
 import re
 import warnings
 from pathlib import Path
-from typing import Any, Union, cast
+from typing import Any, Union

 import pytest
 from packaging import version
@@ -121,11 +121,10 @@ def test_create_system_message_prompt_template_from_template_partial() -> None:
     History:
     {history}
     """
-    json_prompt_instructions: dict = {}
     graph_analyst_template = SystemMessagePromptTemplate.from_template(
         template=graph_creator_content,
         input_variables=["history"],
-        partial_variables={"instructions": json_prompt_instructions},
+        partial_variables={"instructions": {}},
     )
     assert graph_analyst_template.format(history="history") == SystemMessage(
         content="\n Your instructions are:\n {}\n History:\n history\n "
@@ -973,8 +972,6 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
         ("system", "You are an AI assistant named {name}."),
         ("system", [{"text": "You are an AI assistant named {name}."}]),
         SystemMessagePromptTemplate.from_template("you are {foo}"),
-        cast(
-            "tuple",
         (
             "human",
             [
@@ -1013,7 +1010,6 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None:
                 {"image_url": {"url": "data:image/jpeg;base64,foobar"}},
             ],
         ),
-        ),
         ("placeholder", "{chat_history}"),
         MessagesPlaceholder("more_history", optional=False),
     ]
@@ -1179,7 +1175,7 @@ def test_chat_prompt_template_data_prompt_from_message(
     cache_control_placeholder: str,
     source_data_placeholder: str,
 ) -> None:
-    prompt: dict = {
+    prompt: dict[str, Any] = {
         "type": "image",
         "source_type": "base64",
         "data": f"{source_data_placeholder}",

View File

@@ -385,7 +385,7 @@ class AsIsSelector(BaseExampleSelector):
         raise NotImplementedError

     @override
-    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
         return list(self.examples)
@@ -480,11 +480,13 @@ class AsyncAsIsSelector(BaseExampleSelector):
     def add_example(self, example: dict[str, str]) -> Any:
         raise NotImplementedError

-    def select_examples(self, input_variables: dict[str, str]) -> list[dict]:
+    def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]:
         raise NotImplementedError

     @override
-    async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]:
+    async def aselect_examples(
+        self, input_variables: dict[str, str]
+    ) -> list[dict[str, str]]:
         return list(self.examples)

View File

@@ -15,7 +15,7 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute()

 @contextmanager
-def change_directory(dir_path: Path) -> Iterator:
+def change_directory(dir_path: Path) -> Iterator[None]:
     """Change the working directory to the right folder."""
     origin = Path().absolute()
     try:

View File

@@ -265,7 +265,7 @@ def test_prompt_from_template_with_partial_variables() -> None:
 def test_prompt_missing_input_variables() -> None:
     """Test error is raised when input variables are not provided."""
     template = "This is a {foo} test."
-    input_variables: list = []
+    input_variables: list[str] = []
     with pytest.raises(
         ValueError,
         match=re.escape("check for mismatched or missing input parameters from []"),
@@ -509,7 +509,7 @@ Your variable again: {{ foo }}
 def test_prompt_jinja2_missing_input_variables() -> None:
     """Test error is raised when input variables are not provided."""
     template = "This is a {{ foo }} test."
-    input_variables: list = []
+    input_variables: list[str] = []
     with pytest.warns(UserWarning, match="Missing variables: {'foo'}"):
         PromptTemplate(
             input_variables=input_variables,

View File

@@ -14,11 +14,15 @@ from langchain_core.utils.pydantic import is_basemodel_subclass

 def _fake_runnable(
-    _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any
-) -> Union[BaseModel, dict]:
+    _: Any,
+    *,
+    schema: Union[dict[str, Any], type[BaseModel]],
+    value: Any = 42,
+    **_kwargs: Any,
+) -> Union[BaseModel, dict[str, Any]]:
     if isclass(schema) and is_basemodel_subclass(schema):
         return schema(name="yo", value=value)
-    params = cast("dict", schema)["parameters"]
+    params = cast("dict[str, Any]", schema)["parameters"]
     return {k: 1 if k != "value" else value for k, v in params.items()}
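
Illustration (not part of the commit): the type passed to cast() may be a string, which only the type checker evaluates, so quoting "dict[str, Any]" narrows the value with zero runtime cost (hypothetical helper):

    from typing import Any, cast

    def get_parameters(schema: object) -> dict[str, Any]:
        # cast() is a runtime no-op; the quoted type is read by mypy only.
        return cast("dict[str, Any]", schema)["parameters"]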

View File

@@ -6,8 +6,7 @@ from typing import Any

 import pytest

-from langchain_core.runnables import RunnableConfig, RunnableLambda
-from langchain_core.runnables.base import Runnable
+from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda

 @pytest.mark.asyncio
@@ -97,7 +96,7 @@ def test_batch_concurrency() -> None:
         return f"Completed {x}"

-    runnable: Runnable = RunnableLambda(tracked_function)
+    runnable = RunnableLambda(tracked_function)
     num_tasks = 10
     max_concurrency = 3
@@ -129,7 +128,7 @@ def test_batch_as_completed_concurrency() -> None:
         return f"Completed {x}"

-    runnable: Runnable = RunnableLambda(tracked_function)
+    runnable = RunnableLambda(tracked_function)
     num_tasks = 10
     max_concurrency = 3

View File

@@ -26,7 +26,7 @@ from langchain_core.tracers.stdout import ConsoleCallbackHandler

 def test_ensure_config() -> None:
     run_id = str(uuid.uuid4())
-    arg: dict = {
+    arg: dict[str, Any] = {
         "something": "else",
         "metadata": {"foo": "bar"},
         "configurable": {"baz": "qux"},
@@ -147,7 +147,7 @@ async def test_merge_config_callbacks() -> None:
 def test_config_arbitrary_keys() -> None:
     base: RunnablePassthrough[Any] = RunnablePassthrough()
     bound = base.with_config(my_custom_key="my custom value")
-    config = cast("RunnableBinding", bound).config
+    config = cast("RunnableBinding[Any, Any]", bound).config
     assert config.get("my_custom_key") == "my custom value"

View File

@@ -332,7 +332,8 @@ test_cases = [
 @pytest.mark.parametrize(("runnable", "cases"), test_cases)
 def test_context_runnables(
-    runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
+    runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
+    cases: list[_TestCase],
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert runnable.invoke(cases[0].input) == cases[0].output
@@ -344,7 +345,8 @@ def test_context_runnables(
 @pytest.mark.parametrize(("runnable", "cases"), test_cases)
 async def test_context_runnables_async(
-    runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase]
+    runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]],
+    cases: list[_TestCase],
 ) -> None:
     runnable = runnable if isinstance(runnable, Runnable) else runnable()
     assert await runnable.ainvoke(cases[1].input) == cases[1].output

View File

@@ -34,7 +34,7 @@ from langchain_core.tools import BaseTool

 @pytest.fixture
-def llm() -> RunnableWithFallbacks:
+def llm() -> RunnableWithFallbacks[Any, Any]:
     error_llm = FakeListLLM(responses=["foo"], i=1)
     pass_llm = FakeListLLM(responses=["bar"])
@@ -42,7 +42,7 @@ def llm() -> RunnableWithFallbacks:

 @pytest.fixture
-def llm_multi() -> RunnableWithFallbacks:
+def llm_multi() -> RunnableWithFallbacks[Any, Any]:
     error_llm = FakeListLLM(responses=["foo"], i=1)
     error_llm_2 = FakeListLLM(responses=["baz"], i=1)
     pass_llm = FakeListLLM(responses=["bar"])
@@ -51,7 +51,7 @@ def llm_multi() -> RunnableWithFallbacks:

 @pytest.fixture
-def chain() -> Runnable:
+def chain() -> Runnable[Any, str]:
     error_llm = FakeListLLM(responses=["foo"], i=1)
     pass_llm = FakeListLLM(responses=["bar"])
@@ -61,18 +61,18 @@ def chain() -> Runnable:
     )

-def _raise_error(_: dict) -> str:
+def _raise_error(_: dict[str, Any]) -> str:
     raise ValueError

-def _dont_raise_error(inputs: dict) -> str:
+def _dont_raise_error(inputs: dict[str, Any]) -> str:
     if "exception" in inputs:
         return "bar"
     raise ValueError

 @pytest.fixture
-def chain_pass_exceptions() -> Runnable:
+def chain_pass_exceptions() -> Runnable[Any, str]:
     fallback = RunnableLambda(_dont_raise_error)
     return {"text": RunnablePassthrough()} | RunnableLambda(
         _raise_error
@@ -80,13 +80,13 @@ def chain_pass_exceptions() -> Runnable:
 @pytest.mark.parametrize(
-    "runnable",
+    "runnable_name",
     ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
 )
 def test_fallbacks(
-    runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion
+    runnable_name: str, request: Any, snapshot: SnapshotAssertion
 ) -> None:
-    runnable = request.getfixturevalue(runnable)
+    runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
     assert runnable.invoke("hello") == "bar"
     assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(runnable.stream("hello")) == ["bar"]
@@ -94,17 +94,17 @@ def test_fallbacks(
 @pytest.mark.parametrize(
-    "runnable",
+    "runnable_name",
     ["llm", "llm_multi", "chain", "chain_pass_exceptions"],
 )
-async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None:
-    runnable = request.getfixturevalue(runnable)
+async def test_fallbacks_async(runnable_name: str, request: Any) -> None:
+    runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name)
     assert await runnable.ainvoke("hello") == "bar"
     assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3
     assert list(await runnable.ainvoke("hello")) == list("bar")

-def _runnable(inputs: dict) -> str:
+def _runnable(inputs: dict[str, Any]) -> str:
     if inputs["text"] == "foo":
         return "first"
     if "exception" not in inputs:
@@ -117,7 +117,7 @@ def _runnable(inputs: dict) -> str:
     return "third"

-def _assert_potential_error(actual: list, expected: list) -> None:
+def _assert_potential_error(actual: list[Any], expected: list[Any]) -> None:
     for x, y in zip(actual, expected):
         if isinstance(x, Exception):
             assert isinstance(y, type(x))
@@ -260,17 +260,17 @@ async def test_abatch() -> None:
     _assert_potential_error(actual, expected)

-def _generate(_: Iterator) -> Iterator[str]:
+def _generate(_: Iterator[Any]) -> Iterator[str]:
     yield from "foo bar"

-def _generate_immediate_error(_: Iterator) -> Iterator[str]:
+def _generate_immediate_error(_: Iterator[Any]) -> Iterator[str]:
     msg = "immmediate error"
     raise ValueError(msg)
     yield ""

-def _generate_delayed_error(_: Iterator) -> Iterator[str]:
+def _generate_delayed_error(_: Iterator[Any]) -> Iterator[str]:
     yield ""
     msg = "delayed error"
     raise ValueError(msg)
@@ -289,18 +289,18 @@ def test_fallbacks_stream() -> None:
         list(runnable.stream({}))

-async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]:
+async def _agenerate(_: AsyncIterator[Any]) -> AsyncIterator[str]:
     for c in "foo bar":
         yield c

-async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]:
+async def _agenerate_immediate_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
     msg = "immmediate error"
     raise ValueError(msg)
     yield ""

-async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]:
+async def _agenerate_delayed_error(_: AsyncIterator[Any]) -> AsyncIterator[str]:
     yield ""
     msg = "delayed error"
     raise ValueError(msg)
@@ -346,7 +346,7 @@ class FakeStructuredOutputModel(BaseChatModel):
     @override
     def with_structured_output(
         self, schema: Union[dict, type[BaseModel]], **kwargs: Any
-    ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]:
+    ) -> Runnable[Any, dict[str, int]]:
         return RunnableLambda(lambda _: {"foo": self.foo})

     @property
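
Illustration (not part of the commit): renaming the parameter to runnable_name reflects what parametrize actually supplies, fixture names as strings that request.getfixturevalue resolves at test time (hypothetical fixture):

    import pytest

    @pytest.fixture
    def greeting() -> str:
        return "hello"

    # The parameter is honestly typed as str; the fixture value is
    # looked up by name when the test runs.
    @pytest.mark.parametrize("fixture_name", ["greeting"])
    def test_resolved_fixture(fixture_name: str, request: pytest.FixtureRequest) -> None:
        value: str = request.getfixturevalue(fixture_name)
        assert value == "hello"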

View File

@@ -1,4 +1,4 @@
-from typing import Any, Optional
+from typing import Any, Optional, Union

 from packaging import version
 from pydantic import BaseModel
@@ -6,6 +6,7 @@ from syrupy.assertion import SnapshotAssertion
 from typing_extensions import override

 from langchain_core.language_models import FakeListLLM
+from langchain_core.messages import BaseMessage
 from langchain_core.output_parsers.list import CommaSeparatedListOutputParser
 from langchain_core.output_parsers.string import StrOutputParser
 from langchain_core.output_parsers.xml import XMLOutputParser
@@ -222,7 +223,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None:
     str_parser = StrOutputParser()
     xml_parser = XMLOutputParser()

-    def conditional_str_parser(value: str) -> Runnable:
+    def conditional_str_parser(value: str) -> Runnable[Union[BaseMessage, str], str]:
         if value == "a":
             return str_parser
         return xml_parser
@@ -528,7 +529,7 @@ def test_graph_mermaid_escape_node_label() -> None:
 def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None:
     fake_llm = FakeListLLM(responses=["foo", "bar"])
-    sequence: Runnable = (
+    sequence = (
         PromptTemplate.from_template("Hello, {input}")
         | {
             "llm1": fake_llm,

View File

@@ -35,7 +35,7 @@ def test_interfaces() -> None:

 def _get_get_session_history(
     *,
-    store: Optional[dict[str, Any]] = None,
+    store: Optional[dict[str, InMemoryChatMessageHistory]] = None,
 ) -> Callable[..., InMemoryChatMessageHistory]:
     chat_history_store = store if store is not None else {}
@@ -54,7 +54,7 @@ def test_input_messages() -> None:
         lambda messages: "you said: "
         + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
     )
-    store: dict = {}
+    store: dict[str, InMemoryChatMessageHistory] = {}
     get_session_history = _get_get_session_history(store=store)
     with_history = RunnableWithMessageHistory(runnable, get_session_history)
     config: RunnableConfig = {"configurable": {"session_id": "1"}}
@@ -83,7 +83,7 @@ async def test_input_messages_async() -> None:
         lambda messages: "you said: "
         + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
     )
-    store: dict = {}
+    store: dict[str, InMemoryChatMessageHistory] = {}
     get_session_history = _get_get_session_history(store=store)
     with_history = RunnableWithMessageHistory(runnable, get_session_history)
     config = {"session_id": "1_async"}
@@ -489,7 +489,7 @@ def test_get_output_schema() -> None:
     )
     output_type = with_history.get_output_schema()

-    expected_schema: dict = {
+    expected_schema: dict[str, Any] = {
         "title": "RunnableWithChatHistoryOutput",
         "type": "object",
     }
@@ -842,8 +842,7 @@ def test_get_output_messages_no_value_error() -> None:
         lambda messages: "you said: "
         + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage))
     )
-    store: dict = {}
-    get_session_history = _get_get_session_history(store=store)
+    get_session_history = _get_get_session_history()
     with_history = RunnableWithMessageHistory(runnable, get_session_history)
     config: RunnableConfig = {
         "configurable": {"session_id": "1", "message_history": get_session_history("1")}
@@ -859,8 +858,7 @@ def test_get_output_messages_no_value_error() -> None:
 def test_get_output_messages_with_value_error() -> None:
     illegal_bool_message = False
     runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message)
-    store: dict = {}
-    get_session_history = _get_get_session_history(store=store)
+    get_session_history = _get_get_session_history()
     with_history = RunnableWithMessageHistory(runnable, get_session_history)  # type: ignore[arg-type]
     config: RunnableConfig = {
         "configurable": {"session_id": "1", "message_history": get_session_history("1")}

View File

@@ -513,7 +513,7 @@ def test_passthrough_assign_schema() -> None:
     prompt = PromptTemplate.from_template("{context} {question}")
     fake_llm = FakeListLLM(responses=["a"])  # str -> list[list[str]]

-    seq_w_assign: Runnable = (
+    seq_w_assign = (
         RunnablePassthrough.assign(context=itemgetter("question") | retriever)
         | prompt
         | fake_llm
@@ -530,7 +530,7 @@ def test_passthrough_assign_schema() -> None:
         "type": "string",
     }

-    invalid_seq_w_assign: Runnable = (
+    invalid_seq_w_assign = (
         RunnablePassthrough.assign(context=itemgetter("question") | retriever)
         | fake_llm
     )
@@ -1011,7 +1011,7 @@ def test_passthrough_tap(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     mock = mocker.Mock()

-    seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
+    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

     assert seq.invoke("hello", my_kwarg="value") == 5
     assert mock.call_args_list == [
@@ -1078,7 +1078,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None:
     fake = FakeRunnable()
     mock = mocker.Mock()

-    seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock)
+    seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock)

     assert await seq.ainvoke("hello", my_kwarg="value") == 5
     assert mock.call_args_list == [
@@ -1188,8 +1188,8 @@ def test_with_config(mocker: MockerFixture) -> None:
     ]
     spy.reset_mock()

-    fake_1: Runnable = RunnablePassthrough()
-    fake_2: Runnable = RunnablePassthrough()
+    fake_1 = RunnablePassthrough[Any]()
+    fake_2 = RunnablePassthrough[Any]()
     spy_seq_step = mocker.spy(fake_1.__class__, "invoke")

     sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config(
@@ -1650,7 +1650,7 @@ def test_with_listeners(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain: Runnable = prompt | chat
+    chain = prompt | chat

     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
@@ -1683,7 +1683,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None:
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain: Runnable = prompt | chat
+    chain = prompt | chat

     mock_start = mocker.Mock()
     mock_end = mocker.Mock()
@@ -1787,7 +1787,7 @@ def test_prompt_with_chat_model(
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain: Runnable = prompt | chat
+    chain = prompt | chat

     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -1893,7 +1893,7 @@ async def test_prompt_with_chat_model_async(
     )
     chat = FakeListChatModel(responses=["foo"])

-    chain: Runnable = prompt | chat
+    chain = prompt | chat

     assert repr(chain) == snapshot
     assert isinstance(chain, RunnableSequence)
@@ -2007,7 +2007,7 @@ async def test_prompt_with_llm(
     )
     llm = FakeListLLM(responses=["foo", "bar"])

-    chain: Runnable = prompt | llm
+    chain = prompt | llm

     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
@@ -2204,7 +2204,7 @@ async def test_prompt_with_llm_parser(
     llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"])
     parser = CommaSeparatedListOutputParser()

-    chain: Runnable = prompt | llm | parser
+    chain = prompt | llm | parser

     assert isinstance(chain, RunnableSequence)
     assert chain.first == prompt
@@ -2517,7 +2517,7 @@ async def test_stream_log_lists() -> None:
         for i in range(4):
             yield AddableDict(alist=[str(i)])

-    chain: Runnable = RunnableGenerator(list_producer)
+    chain = RunnableGenerator(list_producer)

     stream_log = [
         part async for part in chain.astream_log({"question": "What is your name?"})
@@ -2697,7 +2697,7 @@ def test_combining_sequences(
     chain2 = cast("RunnableSequence", input_formatter | prompt2 | chat2 | parser2)

-    assert isinstance(chain, RunnableSequence)
+    assert isinstance(chain2, RunnableSequence)
     assert chain2.first == input_formatter
     assert chain2.middle == [prompt2, chat2]
     assert chain2.last == parser2
@@ -2705,6 +2705,7 @@ def test_combining_sequences(
     combined_chain = cast("RunnableSequence", chain | chain2)

+    assert isinstance(combined_chain, RunnableSequence)
     assert combined_chain.first == prompt
     assert combined_chain.middle == [
         chat,
@@ -2869,13 +2870,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
 @freeze_time("2023-01-01")
 def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None:
-    chain1: Runnable = ChatPromptTemplate.from_template(
+    chain1 = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    chain2: Runnable = ChatPromptTemplate.from_template(
+    chain2 = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
+    router = RouterRunnable({"math": chain1, "english": chain2})
     chain: Runnable = {
         "key": lambda x: x["key"],
         "input": {"question": lambda x: x["question"]},
@@ -2913,13 +2914,13 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) ->
 async def test_router_runnable_async() -> None:
-    chain1: Runnable = ChatPromptTemplate.from_template(
+    chain1 = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    chain2: Runnable = ChatPromptTemplate.from_template(
+    chain2 = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    router: Runnable = RouterRunnable({"math": chain1, "english": chain2})
+    router = RouterRunnable({"math": chain1, "english": chain2})
     chain: Runnable = {
         "key": lambda x: x["key"],
         "input": {"question": lambda x: x["question"]},
@@ -2941,13 +2942,13 @@ async def test_router_runnable_async() -> None:
 def test_higher_order_lambda_runnable(
     mocker: MockerFixture, snapshot: SnapshotAssertion
 ) -> None:
-    math_chain: Runnable = ChatPromptTemplate.from_template(
+    math_chain = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    english_chain: Runnable = ChatPromptTemplate.from_template(
+    english_chain = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    input_map: Runnable = RunnableParallel(
+    input_map = RunnableParallel(
         key=lambda x: x["key"],
         input={"question": lambda x: x["question"]},
     )
@@ -2997,13 +2998,13 @@ def test_higher_order_lambda_runnable(
 async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None:
-    math_chain: Runnable = ChatPromptTemplate.from_template(
+    math_chain = ChatPromptTemplate.from_template(
         "You are a math genius. Answer the question: {question}"
     ) | FakeListLLM(responses=["4"])
-    english_chain: Runnable = ChatPromptTemplate.from_template(
+    english_chain = ChatPromptTemplate.from_template(
         "You are an english major. Answer the question: {question}"
     ) | FakeListLLM(responses=["2"])
-    input_map: Runnable = RunnableParallel(
+    input_map = RunnableParallel(
         key=lambda x: x["key"],
         input={"question": lambda x: x["question"]},
     )
@@ -3779,7 +3780,7 @@ async def test_deep_astream_assign() -> None:
 def test_runnable_sequence_transform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])

-    chain: Runnable = llm | StrOutputParser()
+    chain = llm | StrOutputParser()

     stream = chain.transform(llm.stream("Hi there!"))
@@ -3792,7 +3793,7 @@ def test_runnable_sequence_transform() -> None:
 async def test_runnable_sequence_atransform() -> None:
     llm = FakeStreamingListLLM(responses=["foo-lish"])

-    chain: Runnable = llm | StrOutputParser()
+    chain = llm | StrOutputParser()

     stream = chain.atransform(llm.astream("Hi there!"))
@@ -3867,7 +3868,7 @@ def test_recursive_lambda() -> None:
 def test_retrying(mocker: MockerFixture) -> None:
-    def _lambda(x: int) -> Union[int, Runnable]:
+    def _lambda(x: int) -> int:
         if x == 1:
             msg = "x is 1"
             raise ValueError(msg)
@@ -3932,7 +3933,7 @@ def test_retrying(mocker: MockerFixture) -> None:
 async def test_async_retrying(mocker: MockerFixture) -> None:
-    def _lambda(x: int) -> Union[int, Runnable]:
+    def _lambda(x: int) -> int:
         if x == 1:
             msg = "x is 1"
             raise ValueError(msg)
@@ -4046,7 +4047,7 @@ async def test_runnable_lambda_astream() -> None:
     """Test that astream works for both normal functions & those returning Runnable."""

     # Wrapper to make a normal function async
-    def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]:
+    def awrapper(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
         async def afunc(*args: Any, **kwargs: Any) -> Any:
             return func(*args, **kwargs)
@@ -4140,8 +4141,8 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None:
         def _batch(
             self,
             inputs: list[str],
-        ) -> list:
-            outputs: list[Any] = []
+        ) -> list[Union[str, Exception]]:
+            outputs: list[Union[str, Exception]] = []
             for value in inputs:
                 if value.startswith(self.fail_starts_with):
                     outputs.append(
@@ -4281,8 +4282,8 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None:
         async def _abatch(
             self,
             inputs: list[str],
-        ) -> list:
-            outputs: list[Any] = []
+        ) -> list[Union[str, Exception]]:
+            outputs: list[Union[str, Exception]] = []
             for value in inputs:
                 if value.startswith(self.fail_starts_with):
                     outputs.append(
@@ -5534,7 +5535,7 @@ def test_listeners() -> None:
     from langchain_core.runnables import RunnableLambda
     from langchain_core.tracers.schemas import Run

-    def fake_chain(inputs: dict) -> dict:
+    def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
         return {**inputs, "key": "extra"}

     shared_state = {}
@@ -5564,7 +5565,7 @@ async def test_listeners_async() -> None:
     from langchain_core.runnables import RunnableLambda
     from langchain_core.tracers.schemas import Run

-    def fake_chain(inputs: dict) -> dict:
+    def fake_chain(inputs: dict[str, str]) -> dict[str, str]:
         return {**inputs, "key": "extra"}

     shared_state = {}
@@ -5577,7 +5578,7 @@ async def test_listeners_async() -> None:
     def on_end(run: Run) -> None:
         shared_state[run.id]["outputs"] = run.inputs

-    chain: Runnable = (
+    chain = (
         RunnableLambda(fake_chain)
         .with_listeners(on_end=on_end, on_start=on_start)
         .map()
@@ -5647,7 +5648,7 @@ def test_pydantic_protected_namespaces() -> None:
     with warnings.catch_warnings():
         warnings.simplefilter("error")

-        class CustomChatModel(RunnableSerializable):
+        class CustomChatModel(RunnableSerializable[str, str]):
             model_kwargs: dict[str, Any] = Field(default_factory=dict)
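
Illustration (not part of the commit): two idioms recur in this file, letting `|` composition infer the sequence type instead of widening to a bare Runnable annotation, and parameterizing a generic runnable explicitly at the call site:

    from langchain_core.runnables import RunnableLambda, RunnablePassthrough

    # Inference keeps the input/output parameters; a bare `chain: Runnable`
    # annotation would have erased them.
    chain = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
    assert chain.invoke(3) == 8

    # Explicit parameterization, as with RunnablePassthrough[Any](mock) above.
    passthrough = RunnablePassthrough[int]()
    assert passthrough.invoke(5) == 5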

View File

@@ -2,7 +2,7 @@

 import asyncio
 import sys
-from collections.abc import AsyncIterator, Sequence
+from collections.abc import AsyncIterator, Mapping, Sequence
 from itertools import cycle
 from typing import Any, cast
@@ -44,12 +44,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
     return cast("list[StreamEvent]", [{**event, "run_id": ""} for event in events])

-async def _as_async_iterator(iterable: list) -> AsyncIterator:
-    """Converts an iterable into an async iterator."""
-    for item in iterable:
-        yield item

 async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEvent]:
     """Collect the events and remove the run ids."""
     materialized_events = [event async for event in events]
@@ -59,7 +53,9 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEven
     return events_

-def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None:
+def _assert_events_equal_allow_superset_metadata(
+    events: Sequence[Mapping[str, Any]], expected: Sequence[Mapping[str, Any]]
+) -> None:
     """Assert that the events are equal."""
     assert len(events) == len(expected)
     for i, (event, expected_event) in enumerate(zip(events, expected)):
@@ -1910,7 +1906,7 @@ async def test_runnable_with_message_history() -> None:
     # Here we use a global variable to store the chat message history.
     # This will make it easier to inspect it to see the underlying results.
-    store: dict = {}
+    store: dict[str, list[BaseMessage]] = {}

     def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
         """Get a chat message history."""

View File

@@ -74,12 +74,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]:
     )

-async def _as_async_iterator(iterable: list) -> AsyncIterator:
-    """Converts an iterable into an async iterator."""
-    for item in iterable:
-        yield item

 async def _collect_events(
     events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True
 ) -> list[StreamEvent]:
@@ -1866,7 +1860,7 @@ async def test_runnable_with_message_history() -> None:
     # Here we use a global variable to store the chat message history.
     # This will make it easier to inspect it to see the underlying results.
-    store: dict = {}
+    store: dict[str, list[BaseMessage]] = {}

     def get_by_session_id(session_id: str) -> BaseChatMessageHistory:
         """Get a chat message history."""

View File

@@ -20,7 +20,7 @@ from langchain_core.runnables.base import RunnableLambda, RunnableParallel
 from langchain_core.tracers.langchain import LangChainTracer
 
 
-def _get_posts(client: Client) -> list:
+def _get_posts(client: Client) -> list[dict[str, Any]]:
     mock_calls = client.session.request.mock_calls  # type: ignore[attr-defined]
     posts = []
     for call in mock_calls:
@@ -274,7 +274,7 @@ class TestRunnableSequenceParallelTraceNesting:
         def before(x: int) -> int:
             return x
 
-        def after(x: dict) -> int:
+        def after(x: dict[str, Any]) -> int:
             return x["chain_result"]
 
         sequence = before | parallel | after
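The dict[str, Any] annotation on after reflects that RunnableParallel emits a dict keyed by branch name. A runnable sketch of the same shape (the increment branch is an assumption, not the test's actual parallel):

    from typing import Any

    from langchain_core.runnables import RunnableLambda, RunnableParallel


    def before(x: int) -> int:
        return x


    def after(x: dict[str, Any]) -> int:
        # RunnableParallel outputs a dict keyed by branch name.
        return x["chain_result"]


    parallel = RunnableParallel(chain_result=RunnableLambda(lambda x: x + 1))
    sequence = RunnableLambda(before) | parallel | RunnableLambda(after)
    assert sequence.invoke(1) == 2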

View File

@@ -1,5 +1,5 @@
 import sys
-from typing import Callable
+from typing import Any, Callable
 
 import pytest
@@ -22,7 +22,7 @@ from langchain_core.runnables.utils import (
         (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"),  # noqa: FURB136
     ],
 )
-def test_get_lambda_source(func: Callable, expected_source: str) -> None:
+def test_get_lambda_source(func: Callable[..., Any], expected_source: str) -> None:
     """Test get_lambda_source function."""
     source = get_lambda_source(func)
     assert source == expected_source
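The Callable change here is purely about explicitness: bare Callable already means Callable[..., Any] implicitly, but strict mypy options such as disallow_any_generics reject the bare form. A small sketch (both helpers are hypothetical):

    from typing import Any, Callable


    def apply_twice(func: Callable[[int], int], x: int) -> int:
        # Fully parametrized: mypy checks the argument and the return type.
        return func(func(x))


    def source_name(func: Callable[..., Any]) -> str:
        # `...` accepts any signature, but the generic is now spelled out.
        return getattr(func, "__name__", "<lambda>")


    assert apply_twice(lambda n: n + 1, 0) == 2
    assert source_name(len) == "len"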

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 import pytest
 from langchain_tests.integration_tests.base_store import (
     BaseStoreAsyncTests,
@@ -8,7 +10,7 @@ from langchain_core.stores import InMemoryStore
 # Check against standard tests
-class TestSyncInMemoryStore(BaseStoreSyncTests):
+class TestSyncInMemoryStore(BaseStoreSyncTests[Any]):
     @pytest.fixture
     def kv_store(self) -> InMemoryStore:
         return InMemoryStore()
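BaseStoreSyncTests is generic in the stored value type, so subclasses now pick a parameter; [Any] preserves the old untyped behavior. A minimal sketch of the pattern (StoreSuite is a hypothetical stand-in, not the langchain_tests API):

    from typing import Any, Generic, TypeVar

    V = TypeVar("V")


    class StoreSuite(Generic[V]):
        """Hypothetical stand-in: V is the type of the stored values."""

        def check_roundtrip(self, store: dict[str, V], key: str, value: V) -> None:
            store[key] = value
            assert store[key] == value


    class AnyValueSuite(StoreSuite[Any]):
        """Parametrizing with Any mirrors BaseStoreSyncTests[Any] above."""


    AnyValueSuite().check_roundtrip({}, "k", object())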

View File

@@ -39,7 +39,6 @@ from langchain_core.messages import ToolCall, ToolMessage
 from langchain_core.messages.tool import ToolOutputMixin
 from langchain_core.retrievers import BaseRetriever
 from langchain_core.runnables import (
-    Runnable,
     RunnableConfig,
     RunnableLambda,
     ensure_config,
@@ -72,7 +71,7 @@ from tests.unit_tests.fake.callbacks import FakeCallbackHandler
 from tests.unit_tests.pydantic_utils import _schema
 
 
-def _get_tool_call_json_schema(tool: BaseTool) -> dict:
+def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]:
     tool_schema = tool.tool_call_schema
     if isinstance(tool_schema, dict):
         return tool_schema
@@ -1402,15 +1401,15 @@ class _MockStructuredToolWithRawOutput(BaseTool):
         self,
         arg1: int,
         arg2: bool,  # noqa: FBT001
-        arg3: Optional[dict] = None,
-    ) -> tuple[str, dict]:
+        arg3: Optional[dict[str, str]] = None,
+    ) -> tuple[str, dict[str, Any]]:
         return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
 
 
 @tool("structured_api", response_format="content_and_artifact")
 def _mock_structured_tool_with_artifact(
-    *, arg1: int, arg2: bool, arg3: Optional[dict] = None
-) -> tuple[str, dict]:
+    *, arg1: int, arg2: bool, arg3: Optional[dict[str, str]] = None
+) -> tuple[str, dict[str, Any]]:
     """A Structured Tool."""
     return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3}
@@ -1419,7 +1418,7 @@ def _mock_structured_tool_with_artifact(
     "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact]
 )
 def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None:
-    tool_call: dict = {
+    tool_call: dict[str, Any] = {
         "name": "structured_api",
         "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}},
         "id": "123",
@@ -1448,7 +1447,7 @@ def test_convert_from_runnable_dict() -> None:
     def f(x: Args) -> str:
         return str(x["a"] * max(x["b"]))
 
-    runnable: Runnable = RunnableLambda(f)
+    runnable = RunnableLambda(f)
     as_tool = runnable.as_tool()
     args_schema = as_tool.args_schema
     assert args_schema is not None
@@ -1480,14 +1479,14 @@ def test_convert_from_runnable_dict() -> None:
         a: int = Field(..., description="Integer")
         b: list[int] = Field(..., description="List of ints")
 
-    runnable = RunnableLambda(g)
-    as_tool = runnable.as_tool(GSchema)
-    as_tool.invoke({"a": 3, "b": [1, 2]})
+    runnable2 = RunnableLambda(g)
+    as_tool2 = runnable2.as_tool(GSchema)
+    as_tool2.invoke({"a": 3, "b": [1, 2]})
 
     # Specify via arg_types:
-    runnable = RunnableLambda(g)
-    as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
-    result = as_tool.invoke({"a": 3, "b": [1, 2]})
+    runnable3 = RunnableLambda(g)
+    as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]})
+    result = as_tool3.invoke({"a": 3, "b": [1, 2]})
     assert result == "6"
 
     # Test with config
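The annotation removals in this file share one motive, sketched below (f mirrors the test; the assert is an assumption about the same inputs): annotating `runnable: Runnable` erased the inferred type parameters, whereas letting mypy infer RunnableLambda[Args, str] keeps them, which is what as_tool() uses to derive an args schema from the TypedDict.

    from typing_extensions import TypedDict

    from langchain_core.runnables import RunnableLambda


    class Args(TypedDict):
        a: int
        b: list[int]


    def f(x: Args) -> str:
        return str(x["a"] * max(x["b"]))


    # Inferred as RunnableLambda[Args, str]; an explicit `Runnable` annotation
    # would have widened both parameters away.
    runnable = RunnableLambda(f)
    as_tool = runnable.as_tool()
    assert as_tool.invoke({"a": 3, "b": [1, 2]}) == "6"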
@@ -1496,9 +1495,9 @@ def test_convert_from_runnable_dict() -> None:
         assert config["configurable"]["foo"] == "not-bar"
         return str(x["a"] * max(x["b"]))
 
-    runnable = RunnableLambda(h)
-    as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]})
-    result = as_tool.invoke(
+    runnable4 = RunnableLambda(h)
+    as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]})
+    result = as_tool4.invoke(
         {"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}}
     )
     assert result == "6"
@@ -1512,7 +1511,7 @@ def test_convert_from_runnable_other() -> None:
     def g(x: str) -> str:
         return x + "z"
 
-    runnable: Runnable = RunnableLambda(f) | g
+    runnable = RunnableLambda(f) | g
     as_tool = runnable.as_tool()
     args_schema = as_tool.args_schema
     assert args_schema is None
@@ -1527,10 +1526,10 @@ def test_convert_from_runnable_other() -> None:
         assert config["configurable"]["foo"] == "not-bar"
         return x + "a"
 
-    runnable = RunnableLambda(h)
-    as_tool = runnable.as_tool()
-    result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}})
-    assert result == "ba"
+    runnable2 = RunnableLambda(h)
+    as_tool2 = runnable2.as_tool()
+    result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}})
+    assert result2 == "ba"
 
 
 @tool("foo", parse_docstring=True)
@@ -1785,7 +1784,7 @@ def test_tool_inherited_injected_arg() -> None:
     }
 
 
-def _get_parametrized_tools() -> list:
+def _get_parametrized_tools() -> list[Callable[..., Any]]:
     def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
         """my_tool."""
         return some_tool
@@ -1800,7 +1799,7 @@ def _get_parametrized_tools() -> list:
 @pytest.mark.parametrize("tool_", _get_parametrized_tools())
-def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
+def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None:
     assert convert_to_openai_function(tool_) == {
         "name": tool_.__name__,
         "description": "my_tool.",
@@ -2528,13 +2527,13 @@ def test_tool_decorator_description() -> None:
     assert foo_args_jsons_schema.description == "JSON Schema."
     assert (
-        cast("dict", foo_args_jsons_schema.tool_call_schema)["description"]
+        cast("dict[str, Any]", foo_args_jsons_schema.tool_call_schema)["description"]
         == "JSON Schema."
     )
 
     assert foo_args_jsons_schema_with_description.description == "description"
     assert (
-        cast("dict", foo_args_jsons_schema_with_description.tool_call_schema)[
+        cast("dict[str, Any]", foo_args_jsons_schema_with_description.tool_call_schema)[
             "description"
         ]
         == "description"