diff --git a/libs/core/langchain_core/document_loaders/langsmith.py b/libs/core/langchain_core/document_loaders/langsmith.py index 57cac1347c5..33446350b07 100644 --- a/libs/core/langchain_core/document_loaders/langsmith.py +++ b/libs/core/langchain_core/document_loaders/langsmith.py @@ -125,7 +125,7 @@ class LangSmithLoader(BaseLoader): yield Document(content_str, metadata=metadata) -def _stringify(x: Union[str, dict]) -> str: +def _stringify(x: Union[str, dict[str, Any]]) -> str: if isinstance(x, str): return x try: diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 21815d7626c..cb3e18d18b9 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -202,13 +202,17 @@ def message_chunk_to_message(chunk: BaseMessageChunk) -> BaseMessage: MessageLikeRepresentation = Union[ - BaseMessage, list[str], tuple[str, str], str, dict[str, Any] + BaseMessage, + list[str], + tuple[str, Union[str, list[Union[str, dict[str, Any]]]]], + str, + dict[str, Any], ] def _create_message_from_message_type( message_type: str, - content: str, + content: Union[str, list[Union[str, dict[str, Any]]]], name: Optional[str] = None, tool_call_id: Optional[str] = None, tool_calls: Optional[list[dict[str, Any]]] = None, @@ -218,13 +222,13 @@ def _create_message_from_message_type( """Create a message from a message type and content string. Args: - message_type: (str) the type of the message (e.g., "human", "ai", etc.). - content: (str) the content string. - name: (str) the name of the message. Default is None. - tool_call_id: (str) the tool call id. Default is None. - tool_calls: (list[dict[str, Any]]) the tool calls. Default is None. - id: (str) the id of the message. Default is None. - additional_kwargs: (dict[str, Any]) additional keyword arguments. + message_type: the type of the message (e.g., "human", "ai", etc.). + content: the content string. + name: the name of the message. Default is None. + tool_call_id: the tool call id. Default is None. + tool_calls: the tool calls. Default is None. + id: the id of the message. Default is None. + **additional_kwargs: additional keyword arguments. Returns: a message of the appropriate type. @@ -1004,12 +1008,13 @@ def convert_to_openai_messages( oai_messages: list = [] - if is_single := isinstance(messages, (BaseMessage, dict, str)): - messages = [messages] + messages_ = ( + [messages] + if (is_single := isinstance(messages, (BaseMessage, dict, str, tuple))) + else messages + ) - messages = convert_to_messages(messages) - - for i, message in enumerate(messages): + for i, message in enumerate(convert_to_messages(messages_)): oai_msg: dict = {"role": _get_message_openai_role(message)} tool_messages: list = [] content: Union[str, list[dict]] diff --git a/libs/core/langchain_core/output_parsers/openai_functions.py b/libs/core/langchain_core/output_parsers/openai_functions.py index 708eb5bd81b..a748c773eb2 100644 --- a/libs/core/langchain_core/output_parsers/openai_functions.py +++ b/libs/core/langchain_core/output_parsers/openai_functions.py @@ -225,7 +225,7 @@ class PydanticOutputFunctionsParser(OutputFunctionsParser): @model_validator(mode="before") @classmethod - def validate_schema(cls, values: dict) -> Any: + def validate_schema(cls, values: dict[str, Any]) -> Any: """Validate the pydantic schema. Args: diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index 924762457f3..2e9c62347c5 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -3,9 +3,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from pathlib import Path from typing import ( - TYPE_CHECKING, Annotated, Any, Optional, @@ -51,9 +51,6 @@ from langchain_core.prompts.string import ( from langchain_core.utils import get_colored_text from langchain_core.utils.interactive_env import is_interactive_env -if TYPE_CHECKING: - from collections.abc import Sequence - class MessagesPlaceholder(BaseMessagePromptTemplate): """Prompt template that assumes variable is already list of messages. @@ -772,7 +769,7 @@ MessageLikeRepresentation = Union[ MessageLike, tuple[ Union[str, type], - Union[str, list[dict], list[object]], + Union[str, Sequence[dict], Sequence[object]], ], str, dict[str, Any], @@ -1435,11 +1432,13 @@ def _convert_to_message_template( f" Got: {message}" ) raise ValueError(msg) - message = (message["role"], message["content"]) - if len(message) != 2: - msg = f"Expected 2-tuple of (role, template), got {message}" - raise ValueError(msg) - message_type_str, template = message + message_type_str = message["role"] + template = message["content"] + else: + if len(message) != 2: + msg = f"Expected 2-tuple of (role, template), got {message}" + raise ValueError(msg) + message_type_str, template = message if isinstance(message_type_str, str): message_ = _create_template_from_message_type( message_type_str, template, template_format=template_format diff --git a/libs/core/langchain_core/tracers/evaluation.py b/libs/core/langchain_core/tracers/evaluation.py index c7447ad4d81..71049ef00e8 100644 --- a/libs/core/langchain_core/tracers/evaluation.py +++ b/libs/core/langchain_core/tracers/evaluation.py @@ -100,7 +100,7 @@ class EvaluatorCallbackHandler(BaseTracer): ) else: self.executor = None - self.futures: weakref.WeakSet[Future] = weakref.WeakSet() + self.futures: weakref.WeakSet[Future[None]] = weakref.WeakSet() self.skip_unfinished = skip_unfinished self.project_name = project_name self.logged_eval_results: dict[tuple[str, str], list[EvaluationResult]] = {} diff --git a/libs/core/tests/unit_tests/example_selectors/test_base.py b/libs/core/tests/unit_tests/example_selectors/test_base.py index 54793627987..8f94d005b32 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_base.py +++ b/libs/core/tests/unit_tests/example_selectors/test_base.py @@ -10,7 +10,7 @@ class DummyExampleSelector(BaseExampleSelector): def add_example(self, example: dict[str, str]) -> None: self.example = example - def select_examples(self, input_variables: dict[str, str]) -> list[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]: return [input_variables] diff --git a/libs/core/tests/unit_tests/fake/callbacks.py b/libs/core/tests/unit_tests/fake/callbacks.py index b8ec1778b42..2defad94699 100644 --- a/libs/core/tests/unit_tests/fake/callbacks.py +++ b/libs/core/tests/unit_tests/fake/callbacks.py @@ -276,7 +276,9 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): self.on_retriever_error_common() # Overriding since BaseModel has __deepcopy__ method as well - def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": # type: ignore[override] + def __deepcopy__( + self, memo: Union[dict[int, Any], None] = None + ) -> "FakeCallbackHandler": return self @@ -426,5 +428,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi self.on_text_common() # Overriding since BaseModel has __deepcopy__ method as well - def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": # type: ignore[override] + def __deepcopy__( + self, memo: Union[dict[int, Any], None] = None + ) -> "FakeAsyncCallbackHandler": return self diff --git a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py index 41b197cd8c2..7f30c1eadbb 100644 --- a/libs/core/tests/unit_tests/language_models/chat_models/test_base.py +++ b/libs/core/tests/unit_tests/language_models/chat_models/test_base.py @@ -41,7 +41,7 @@ if TYPE_CHECKING: @pytest.fixture -def messages() -> list: +def messages() -> list[BaseMessage]: return [ SystemMessage(content="You are a test user."), HumanMessage(content="Hello, I am a test user."), @@ -49,14 +49,14 @@ def messages() -> list: @pytest.fixture -def messages_2() -> list: +def messages_2() -> list[BaseMessage]: return [ SystemMessage(content="You are a test user."), HumanMessage(content="Hello, I not a test user."), ] -def test_batch_size(messages: list, messages_2: list) -> None: +def test_batch_size(messages: list[BaseMessage], messages_2: list[BaseMessage]) -> None: # The base endpoint doesn't support native batching, # so we expect batch_size to always be 1 llm = FakeListChatModel(responses=[str(i) for i in range(100)]) @@ -80,7 +80,9 @@ def test_batch_size(messages: list, messages_2: list) -> None: assert (cb.traced_runs[0].extra or {}).get("batch_size") == 1 -async def test_async_batch_size(messages: list, messages_2: list) -> None: +async def test_async_batch_size( + messages: list[BaseMessage], messages_2: list[BaseMessage] +) -> None: llm = FakeListChatModel(responses=[str(i) for i in range(100)]) # The base endpoint doesn't support native batching, # so we expect batch_size to always be 1 @@ -262,7 +264,7 @@ async def test_astream_implementation_uses_astream() -> None: class FakeTracer(BaseTracer): def __init__(self) -> None: super().__init__() - self.traced_run_ids: list = [] + self.traced_run_ids: list[uuid.UUID] = [] def _persist_run(self, run: Run) -> None: """Persist a run.""" @@ -415,7 +417,7 @@ async def test_disable_streaming_no_streaming_model_async( class FakeChatModelStartTracer(FakeTracer): def __init__(self) -> None: super().__init__() - self.messages: list = [] + self.messages: list[list[list[BaseMessage]]] = [] def on_chat_model_start(self, *args: Any, **kwargs: Any) -> Run: _, messages = args diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 0898341dda8..a5ec976ba4b 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -17,6 +17,7 @@ from langchain_core.messages import ( ToolMessage, ) from langchain_core.messages.utils import ( + MessageLikeRepresentation, convert_to_messages, convert_to_openai_messages, count_tokens_approximately, @@ -153,7 +154,7 @@ def test_merge_messages_tool_messages() -> None: {"include_names": ["blah", "blur"], "exclude_types": [SystemMessage]}, ], ) -def test_filter_message(filters: dict) -> None: +def test_filter_message(filters: dict[str, Any]) -> None: messages = [ SystemMessage("foo", name="blah", id="1"), HumanMessage("bar", name="blur", id="2"), @@ -673,7 +674,7 @@ class FakeTokenCountingModel(FakeChatModel): def test_convert_to_messages() -> None: - message_like: list = [ + message_like: list[MessageLikeRepresentation] = [ # BaseMessage SystemMessage("1"), SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}), @@ -1179,7 +1180,7 @@ def test_convert_to_openai_messages_mixed_content_types() -> None: def test_convert_to_openai_messages_developer() -> None: - messages: list = [ + messages: list[MessageLikeRepresentation] = [ SystemMessage("a", additional_kwargs={"__openai_role__": "developer"}), {"role": "developer", "content": "a"}, ] diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index 3fbb65c63b0..9c8c341657f 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -17,7 +17,7 @@ from langchain_core.output_parsers.openai_tools import ( ) from langchain_core.outputs import ChatGeneration -STREAMED_MESSAGES: list = [ +STREAMED_MESSAGES = [ AIMessageChunk(content=""), AIMessageChunk( content="", @@ -331,7 +331,7 @@ for message in STREAMED_MESSAGES: STREAMED_MESSAGES_WITH_TOOL_CALLS.append(message) -EXPECTED_STREAMED_JSON = [ +EXPECTED_STREAMED_JSON: list[dict[str, Any]] = [ {}, {"names": ["suz"]}, {"names": ["suzy"]}, @@ -392,7 +392,7 @@ def test_partial_json_output_parser(*, use_tool_calls: bool) -> None: chain = input_iter | JsonOutputToolsParser() actual = list(chain.stream(None)) - expected: list = [[]] + [ + expected: list[list[dict[str, Any]]] = [[]] + [ [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON ] assert actual == expected @@ -404,7 +404,7 @@ async def test_partial_json_output_parser_async(*, use_tool_calls: bool) -> None chain = input_iter | JsonOutputToolsParser() actual = [p async for p in chain.astream(None)] - expected: list = [[]] + [ + expected: list[list[dict[str, Any]]] = [[]] + [ [{"type": "NameCollector", "args": chunk}] for chunk in EXPECTED_STREAMED_JSON ] assert actual == expected @@ -416,7 +416,7 @@ def test_partial_json_output_parser_return_id(*, use_tool_calls: bool) -> None: chain = input_iter | JsonOutputToolsParser(return_id=True) actual = list(chain.stream(None)) - expected: list = [[]] + [ + expected: list[list[dict[str, Any]]] = [[]] + [ [ { "type": "NameCollector", @@ -435,7 +435,9 @@ def test_partial_json_output_key_parser(*, use_tool_calls: bool) -> None: chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") actual = list(chain.stream(None)) - expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + expected: list[list[dict[str, Any]]] = [[]] + [ + [chunk] for chunk in EXPECTED_STREAMED_JSON + ] assert actual == expected @@ -446,7 +448,9 @@ async def test_partial_json_output_parser_key_async(*, use_tool_calls: bool) -> chain = input_iter | JsonOutputKeyToolsParser(key_name="NameCollector") actual = [p async for p in chain.astream(None)] - expected: list = [[]] + [[chunk] for chunk in EXPECTED_STREAMED_JSON] + expected: list[list[dict[str, Any]]] = [[]] + [ + [chunk] for chunk in EXPECTED_STREAMED_JSON + ] assert actual == expected diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 0486878749a..7dfe651622e 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -141,9 +141,7 @@ DEF_EXPECTED_RESULT = TestModel( def test_pydantic_output_parser() -> None: """Test PydanticOutputParser.""" - pydantic_parser: PydanticOutputParser = PydanticOutputParser( - pydantic_object=TestModel - ) + pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) result = pydantic_parser.parse(DEF_RESULT) assert result == DEF_EXPECTED_RESULT @@ -152,9 +150,7 @@ def test_pydantic_output_parser() -> None: def test_pydantic_output_parser_fail() -> None: """Test PydanticOutputParser where completion result fails schema validation.""" - pydantic_parser: PydanticOutputParser = PydanticOutputParser( - pydantic_object=TestModel - ) + pydantic_parser = PydanticOutputParser(pydantic_object=TestModel) with pytest.raises( OutputParserException, match="Failed to parse TestModel from completion" diff --git a/libs/core/tests/unit_tests/outputs/test_chat_generation.py b/libs/core/tests/unit_tests/outputs/test_chat_generation.py index c409a76f0de..d9e8ca9f0b8 100644 --- a/libs/core/tests/unit_tests/outputs/test_chat_generation.py +++ b/libs/core/tests/unit_tests/outputs/test_chat_generation.py @@ -1,4 +1,4 @@ -from typing import Union +from typing import Any, Union import pytest @@ -19,14 +19,14 @@ from langchain_core.outputs import ChatGeneration ], ], ) -def test_msg_with_text(content: Union[str, list]) -> None: +def test_msg_with_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None: expected = "foo" actual = ChatGeneration(message=AIMessage(content=content)).text assert actual == expected @pytest.mark.parametrize("content", [[], [{"tool_use": {}, "type": "tool_use"}]]) -def test_msg_no_text(content: Union[str, list]) -> None: +def test_msg_no_text(content: Union[str, list[Union[str, dict[str, Any]]]]) -> None: expected = "" actual = ChatGeneration(message=AIMessage(content=content)).text assert actual == expected diff --git a/libs/core/tests/unit_tests/prompts/test_chat.py b/libs/core/tests/unit_tests/prompts/test_chat.py index f873deb0877..2703cfdc550 100644 --- a/libs/core/tests/unit_tests/prompts/test_chat.py +++ b/libs/core/tests/unit_tests/prompts/test_chat.py @@ -1,7 +1,7 @@ import re import warnings from pathlib import Path -from typing import Any, Union, cast +from typing import Any, Union import pytest from packaging import version @@ -121,11 +121,10 @@ def test_create_system_message_prompt_template_from_template_partial() -> None: History: {history} """ - json_prompt_instructions: dict = {} graph_analyst_template = SystemMessagePromptTemplate.from_template( template=graph_creator_content, input_variables=["history"], - partial_variables={"instructions": json_prompt_instructions}, + partial_variables={"instructions": {}}, ) assert graph_analyst_template.format(history="history") == SystemMessage( content="\n Your instructions are:\n {}\n History:\n history\n " @@ -973,46 +972,43 @@ def test_chat_tmpl_serdes(snapshot: SnapshotAssertion) -> None: ("system", "You are an AI assistant named {name}."), ("system", [{"text": "You are an AI assistant named {name}."}]), SystemMessagePromptTemplate.from_template("you are {foo}"), - cast( - "tuple", - ( - "human", - [ - "hello", - {"text": "What's in this image?"}, - {"type": "text", "text": "What's in this image?"}, - { - "type": "text", - "text": "What's in this image?", - "cache_control": {"type": "{foo}"}, + ( + "human", + [ + "hello", + {"text": "What's in this image?"}, + {"type": "text", "text": "What's in this image?"}, + { + "type": "text", + "text": "What's in this image?", + "cache_control": {"type": "{foo}"}, + }, + { + "type": "image_url", + "image_url": "data:image/jpeg;base64,{my_image}", + }, + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,{my_image}"}, + }, + {"type": "image_url", "image_url": "{my_other_image}"}, + { + "type": "image_url", + "image_url": { + "url": "{my_other_image}", + "detail": "medium", }, - { - "type": "image_url", - "image_url": "data:image/jpeg;base64,{my_image}", - }, - { - "type": "image_url", - "image_url": {"url": "data:image/jpeg;base64,{my_image}"}, - }, - {"type": "image_url", "image_url": "{my_other_image}"}, - { - "type": "image_url", - "image_url": { - "url": "{my_other_image}", - "detail": "medium", - }, - }, - { - "type": "image_url", - "image_url": {"url": "https://www.langchain.com/image.png"}, - }, - { - "type": "image_url", - "image_url": {"url": ""}, - }, - {"image_url": {"url": ""}}, - ], - ), + }, + { + "type": "image_url", + "image_url": {"url": "https://www.langchain.com/image.png"}, + }, + { + "type": "image_url", + "image_url": {"url": ""}, + }, + {"image_url": {"url": ""}}, + ], ), ("placeholder", "{chat_history}"), MessagesPlaceholder("more_history", optional=False), @@ -1179,7 +1175,7 @@ def test_chat_prompt_template_data_prompt_from_message( cache_control_placeholder: str, source_data_placeholder: str, ) -> None: - prompt: dict = { + prompt: dict[str, Any] = { "type": "image", "source_type": "base64", "data": f"{source_data_placeholder}", diff --git a/libs/core/tests/unit_tests/prompts/test_few_shot.py b/libs/core/tests/unit_tests/prompts/test_few_shot.py index 3e768cf380b..dcae9d08812 100644 --- a/libs/core/tests/unit_tests/prompts/test_few_shot.py +++ b/libs/core/tests/unit_tests/prompts/test_few_shot.py @@ -385,7 +385,7 @@ class AsIsSelector(BaseExampleSelector): raise NotImplementedError @override - def select_examples(self, input_variables: dict[str, str]) -> list[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]: return list(self.examples) @@ -480,11 +480,13 @@ class AsyncAsIsSelector(BaseExampleSelector): def add_example(self, example: dict[str, str]) -> Any: raise NotImplementedError - def select_examples(self, input_variables: dict[str, str]) -> list[dict]: + def select_examples(self, input_variables: dict[str, str]) -> list[dict[str, str]]: raise NotImplementedError @override - async def aselect_examples(self, input_variables: dict[str, str]) -> list[dict]: + async def aselect_examples( + self, input_variables: dict[str, str] + ) -> list[dict[str, str]]: return list(self.examples) diff --git a/libs/core/tests/unit_tests/prompts/test_loading.py b/libs/core/tests/unit_tests/prompts/test_loading.py index 76d75314775..ac0d7510332 100644 --- a/libs/core/tests/unit_tests/prompts/test_loading.py +++ b/libs/core/tests/unit_tests/prompts/test_loading.py @@ -15,7 +15,7 @@ EXAMPLE_DIR = (Path(__file__).parent.parent / "examples").absolute() @contextmanager -def change_directory(dir_path: Path) -> Iterator: +def change_directory(dir_path: Path) -> Iterator[None]: """Change the working directory to the right folder.""" origin = Path().absolute() try: diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index e092eb66581..be281f9a74b 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -265,7 +265,7 @@ def test_prompt_from_template_with_partial_variables() -> None: def test_prompt_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {foo} test." - input_variables: list = [] + input_variables: list[str] = [] with pytest.raises( ValueError, match=re.escape("check for mismatched or missing input parameters from []"), @@ -509,7 +509,7 @@ Your variable again: {{ foo }} def test_prompt_jinja2_missing_input_variables() -> None: """Test error is raised when input variables are not provided.""" template = "This is a {{ foo }} test." - input_variables: list = [] + input_variables: list[str] = [] with pytest.warns(UserWarning, match="Missing variables: {'foo'}"): PromptTemplate( input_variables=input_variables, diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 0b74b37cc6f..112c24b753b 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -14,11 +14,15 @@ from langchain_core.utils.pydantic import is_basemodel_subclass def _fake_runnable( - _: Any, *, schema: Union[dict, type[BaseModel]], value: Any = 42, **_kwargs: Any -) -> Union[BaseModel, dict]: + _: Any, + *, + schema: Union[dict[str, Any], type[BaseModel]], + value: Any = 42, + **_kwargs: Any, +) -> Union[BaseModel, dict[str, Any]]: if isclass(schema) and is_basemodel_subclass(schema): return schema(name="yo", value=value) - params = cast("dict", schema)["parameters"] + params = cast("dict[str, Any]", schema)["parameters"] return {k: 1 if k != "value" else value for k, v in params.items()} diff --git a/libs/core/tests/unit_tests/runnables/test_concurrency.py b/libs/core/tests/unit_tests/runnables/test_concurrency.py index 24d4fad5d23..a18ab172723 100644 --- a/libs/core/tests/unit_tests/runnables/test_concurrency.py +++ b/libs/core/tests/unit_tests/runnables/test_concurrency.py @@ -6,8 +6,7 @@ from typing import Any import pytest -from langchain_core.runnables import RunnableConfig, RunnableLambda -from langchain_core.runnables.base import Runnable +from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda @pytest.mark.asyncio @@ -97,7 +96,7 @@ def test_batch_concurrency() -> None: return f"Completed {x}" - runnable: Runnable = RunnableLambda(tracked_function) + runnable = RunnableLambda(tracked_function) num_tasks = 10 max_concurrency = 3 @@ -129,7 +128,7 @@ def test_batch_as_completed_concurrency() -> None: return f"Completed {x}" - runnable: Runnable = RunnableLambda(tracked_function) + runnable = RunnableLambda(tracked_function) num_tasks = 10 max_concurrency = 3 diff --git a/libs/core/tests/unit_tests/runnables/test_config.py b/libs/core/tests/unit_tests/runnables/test_config.py index dc7f1c5d0ca..5897879d90a 100644 --- a/libs/core/tests/unit_tests/runnables/test_config.py +++ b/libs/core/tests/unit_tests/runnables/test_config.py @@ -26,7 +26,7 @@ from langchain_core.tracers.stdout import ConsoleCallbackHandler def test_ensure_config() -> None: run_id = str(uuid.uuid4()) - arg: dict = { + arg: dict[str, Any] = { "something": "else", "metadata": {"foo": "bar"}, "configurable": {"baz": "qux"}, @@ -147,7 +147,7 @@ async def test_merge_config_callbacks() -> None: def test_config_arbitrary_keys() -> None: base: RunnablePassthrough[Any] = RunnablePassthrough() bound = base.with_config(my_custom_key="my custom value") - config = cast("RunnableBinding", bound).config + config = cast("RunnableBinding[Any, Any]", bound).config assert config.get("my_custom_key") == "my custom value" diff --git a/libs/core/tests/unit_tests/runnables/test_context.py b/libs/core/tests/unit_tests/runnables/test_context.py index b638ac78f57..0756eeaf285 100644 --- a/libs/core/tests/unit_tests/runnables/test_context.py +++ b/libs/core/tests/unit_tests/runnables/test_context.py @@ -332,7 +332,8 @@ test_cases = [ @pytest.mark.parametrize(("runnable", "cases"), test_cases) def test_context_runnables( - runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] + runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]], + cases: list[_TestCase], ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert runnable.invoke(cases[0].input) == cases[0].output @@ -344,7 +345,8 @@ def test_context_runnables( @pytest.mark.parametrize(("runnable", "cases"), test_cases) async def test_context_runnables_async( - runnable: Union[Runnable, Callable[[], Runnable]], cases: list[_TestCase] + runnable: Union[Runnable[Any, Any], Callable[[], Runnable[Any, Any]]], + cases: list[_TestCase], ) -> None: runnable = runnable if isinstance(runnable, Runnable) else runnable() assert await runnable.ainvoke(cases[1].input) == cases[1].output diff --git a/libs/core/tests/unit_tests/runnables/test_fallbacks.py b/libs/core/tests/unit_tests/runnables/test_fallbacks.py index 3985facc544..e0666fbf744 100644 --- a/libs/core/tests/unit_tests/runnables/test_fallbacks.py +++ b/libs/core/tests/unit_tests/runnables/test_fallbacks.py @@ -34,7 +34,7 @@ from langchain_core.tools import BaseTool @pytest.fixture -def llm() -> RunnableWithFallbacks: +def llm() -> RunnableWithFallbacks[Any, Any]: error_llm = FakeListLLM(responses=["foo"], i=1) pass_llm = FakeListLLM(responses=["bar"]) @@ -42,7 +42,7 @@ def llm() -> RunnableWithFallbacks: @pytest.fixture -def llm_multi() -> RunnableWithFallbacks: +def llm_multi() -> RunnableWithFallbacks[Any, Any]: error_llm = FakeListLLM(responses=["foo"], i=1) error_llm_2 = FakeListLLM(responses=["baz"], i=1) pass_llm = FakeListLLM(responses=["bar"]) @@ -51,7 +51,7 @@ def llm_multi() -> RunnableWithFallbacks: @pytest.fixture -def chain() -> Runnable: +def chain() -> Runnable[Any, str]: error_llm = FakeListLLM(responses=["foo"], i=1) pass_llm = FakeListLLM(responses=["bar"]) @@ -61,18 +61,18 @@ def chain() -> Runnable: ) -def _raise_error(_: dict) -> str: +def _raise_error(_: dict[str, Any]) -> str: raise ValueError -def _dont_raise_error(inputs: dict) -> str: +def _dont_raise_error(inputs: dict[str, Any]) -> str: if "exception" in inputs: return "bar" raise ValueError @pytest.fixture -def chain_pass_exceptions() -> Runnable: +def chain_pass_exceptions() -> Runnable[Any, str]: fallback = RunnableLambda(_dont_raise_error) return {"text": RunnablePassthrough()} | RunnableLambda( _raise_error @@ -80,13 +80,13 @@ def chain_pass_exceptions() -> Runnable: @pytest.mark.parametrize( - "runnable", + "runnable_name", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) def test_fallbacks( - runnable: RunnableWithFallbacks, request: Any, snapshot: SnapshotAssertion + runnable_name: str, request: Any, snapshot: SnapshotAssertion ) -> None: - runnable = request.getfixturevalue(runnable) + runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name) assert runnable.invoke("hello") == "bar" assert runnable.batch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(runnable.stream("hello")) == ["bar"] @@ -94,17 +94,17 @@ def test_fallbacks( @pytest.mark.parametrize( - "runnable", + "runnable_name", ["llm", "llm_multi", "chain", "chain_pass_exceptions"], ) -async def test_fallbacks_async(runnable: RunnableWithFallbacks, request: Any) -> None: - runnable = request.getfixturevalue(runnable) +async def test_fallbacks_async(runnable_name: str, request: Any) -> None: + runnable: Runnable[Any, Any] = request.getfixturevalue(runnable_name) assert await runnable.ainvoke("hello") == "bar" assert await runnable.abatch(["hi", "hey", "bye"]) == ["bar"] * 3 assert list(await runnable.ainvoke("hello")) == list("bar") -def _runnable(inputs: dict) -> str: +def _runnable(inputs: dict[str, Any]) -> str: if inputs["text"] == "foo": return "first" if "exception" not in inputs: @@ -117,7 +117,7 @@ def _runnable(inputs: dict) -> str: return "third" -def _assert_potential_error(actual: list, expected: list) -> None: +def _assert_potential_error(actual: list[Any], expected: list[Any]) -> None: for x, y in zip(actual, expected): if isinstance(x, Exception): assert isinstance(y, type(x)) @@ -260,17 +260,17 @@ async def test_abatch() -> None: _assert_potential_error(actual, expected) -def _generate(_: Iterator) -> Iterator[str]: +def _generate(_: Iterator[Any]) -> Iterator[str]: yield from "foo bar" -def _generate_immediate_error(_: Iterator) -> Iterator[str]: +def _generate_immediate_error(_: Iterator[Any]) -> Iterator[str]: msg = "immmediate error" raise ValueError(msg) yield "" -def _generate_delayed_error(_: Iterator) -> Iterator[str]: +def _generate_delayed_error(_: Iterator[Any]) -> Iterator[str]: yield "" msg = "delayed error" raise ValueError(msg) @@ -289,18 +289,18 @@ def test_fallbacks_stream() -> None: list(runnable.stream({})) -async def _agenerate(_: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate(_: AsyncIterator[Any]) -> AsyncIterator[str]: for c in "foo bar": yield c -async def _agenerate_immediate_error(_: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate_immediate_error(_: AsyncIterator[Any]) -> AsyncIterator[str]: msg = "immmediate error" raise ValueError(msg) yield "" -async def _agenerate_delayed_error(_: AsyncIterator) -> AsyncIterator[str]: +async def _agenerate_delayed_error(_: AsyncIterator[Any]) -> AsyncIterator[str]: yield "" msg = "delayed error" raise ValueError(msg) @@ -346,7 +346,7 @@ class FakeStructuredOutputModel(BaseChatModel): @override def with_structured_output( self, schema: Union[dict, type[BaseModel]], **kwargs: Any - ) -> Runnable[LanguageModelInput, Union[dict, BaseModel]]: + ) -> Runnable[Any, dict[str, int]]: return RunnableLambda(lambda _: {"foo": self.foo}) @property diff --git a/libs/core/tests/unit_tests/runnables/test_graph.py b/libs/core/tests/unit_tests/runnables/test_graph.py index 7944d0b1da4..3e2d25f41fa 100644 --- a/libs/core/tests/unit_tests/runnables/test_graph.py +++ b/libs/core/tests/unit_tests/runnables/test_graph.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any, Optional, Union from packaging import version from pydantic import BaseModel @@ -6,6 +6,7 @@ from syrupy.assertion import SnapshotAssertion from typing_extensions import override from langchain_core.language_models import FakeListLLM +from langchain_core.messages import BaseMessage from langchain_core.output_parsers.list import CommaSeparatedListOutputParser from langchain_core.output_parsers.string import StrOutputParser from langchain_core.output_parsers.xml import XMLOutputParser @@ -222,7 +223,7 @@ def test_graph_sequence_map(snapshot: SnapshotAssertion) -> None: str_parser = StrOutputParser() xml_parser = XMLOutputParser() - def conditional_str_parser(value: str) -> Runnable: + def conditional_str_parser(value: str) -> Runnable[Union[BaseMessage, str], str]: if value == "a": return str_parser return xml_parser @@ -528,7 +529,7 @@ def test_graph_mermaid_escape_node_label() -> None: def test_graph_mermaid_duplicate_nodes(snapshot: SnapshotAssertion) -> None: fake_llm = FakeListLLM(responses=["foo", "bar"]) - sequence: Runnable = ( + sequence = ( PromptTemplate.from_template("Hello, {input}") | { "llm1": fake_llm, diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 0a4b5326ff3..7b0c368bab1 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -35,7 +35,7 @@ def test_interfaces() -> None: def _get_get_session_history( *, - store: Optional[dict[str, Any]] = None, + store: Optional[dict[str, InMemoryChatMessageHistory]] = None, ) -> Callable[..., InMemoryChatMessageHistory]: chat_history_store = store if store is not None else {} @@ -54,7 +54,7 @@ def test_input_messages() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: dict = {} + store: dict[str, InMemoryChatMessageHistory] = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = {"configurable": {"session_id": "1"}} @@ -83,7 +83,7 @@ async def test_input_messages_async() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: dict = {} + store: dict[str, InMemoryChatMessageHistory] = {} get_session_history = _get_get_session_history(store=store) with_history = RunnableWithMessageHistory(runnable, get_session_history) config = {"session_id": "1_async"} @@ -489,7 +489,7 @@ def test_get_output_schema() -> None: ) output_type = with_history.get_output_schema() - expected_schema: dict = { + expected_schema: dict[str, Any] = { "title": "RunnableWithChatHistoryOutput", "type": "object", } @@ -842,8 +842,7 @@ def test_get_output_messages_no_value_error() -> None: lambda messages: "you said: " + "\n".join(str(m.content) for m in messages if isinstance(m, HumanMessage)) ) - store: dict = {} - get_session_history = _get_get_session_history(store=store) + get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory(runnable, get_session_history) config: RunnableConfig = { "configurable": {"session_id": "1", "message_history": get_session_history("1")} @@ -859,8 +858,7 @@ def test_get_output_messages_no_value_error() -> None: def test_get_output_messages_with_value_error() -> None: illegal_bool_message = False runnable = _RunnableLambdaWithRaiseError(lambda _: illegal_bool_message) - store: dict = {} - get_session_history = _get_get_session_history(store=store) + get_session_history = _get_get_session_history() with_history = RunnableWithMessageHistory(runnable, get_session_history) # type: ignore[arg-type] config: RunnableConfig = { "configurable": {"session_id": "1", "message_history": get_session_history("1")} diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 4b63fb50ae2..920b1ff9af1 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -513,7 +513,7 @@ def test_passthrough_assign_schema() -> None: prompt = PromptTemplate.from_template("{context} {question}") fake_llm = FakeListLLM(responses=["a"]) # str -> list[list[str]] - seq_w_assign: Runnable = ( + seq_w_assign = ( RunnablePassthrough.assign(context=itemgetter("question") | retriever) | prompt | fake_llm @@ -530,7 +530,7 @@ def test_passthrough_assign_schema() -> None: "type": "string", } - invalid_seq_w_assign: Runnable = ( + invalid_seq_w_assign = ( RunnablePassthrough.assign(context=itemgetter("question") | retriever) | fake_llm ) @@ -1011,7 +1011,7 @@ def test_passthrough_tap(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() - seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock) assert seq.invoke("hello", my_kwarg="value") == 5 assert mock.call_args_list == [ @@ -1078,7 +1078,7 @@ async def test_passthrough_tap_async(mocker: MockerFixture) -> None: fake = FakeRunnable() mock = mocker.Mock() - seq: Runnable = RunnablePassthrough(mock) | fake | RunnablePassthrough(mock) + seq = RunnablePassthrough[Any](mock) | fake | RunnablePassthrough[Any](mock) assert await seq.ainvoke("hello", my_kwarg="value") == 5 assert mock.call_args_list == [ @@ -1188,8 +1188,8 @@ def test_with_config(mocker: MockerFixture) -> None: ] spy.reset_mock() - fake_1: Runnable = RunnablePassthrough() - fake_2: Runnable = RunnablePassthrough() + fake_1 = RunnablePassthrough[Any]() + fake_2 = RunnablePassthrough[Any]() spy_seq_step = mocker.spy(fake_1.__class__, "invoke") sequence = fake_1.with_config(tags=["a-tag"]) | fake_2.with_config( @@ -1650,7 +1650,7 @@ def test_with_listeners(mocker: MockerFixture) -> None: ) chat = FakeListChatModel(responses=["foo"]) - chain: Runnable = prompt | chat + chain = prompt | chat mock_start = mocker.Mock() mock_end = mocker.Mock() @@ -1683,7 +1683,7 @@ async def test_with_listeners_async(mocker: MockerFixture) -> None: ) chat = FakeListChatModel(responses=["foo"]) - chain: Runnable = prompt | chat + chain = prompt | chat mock_start = mocker.Mock() mock_end = mocker.Mock() @@ -1787,7 +1787,7 @@ def test_prompt_with_chat_model( ) chat = FakeListChatModel(responses=["foo"]) - chain: Runnable = prompt | chat + chain = prompt | chat assert repr(chain) == snapshot assert isinstance(chain, RunnableSequence) @@ -1893,7 +1893,7 @@ async def test_prompt_with_chat_model_async( ) chat = FakeListChatModel(responses=["foo"]) - chain: Runnable = prompt | chat + chain = prompt | chat assert repr(chain) == snapshot assert isinstance(chain, RunnableSequence) @@ -2007,7 +2007,7 @@ async def test_prompt_with_llm( ) llm = FakeListLLM(responses=["foo", "bar"]) - chain: Runnable = prompt | llm + chain = prompt | llm assert isinstance(chain, RunnableSequence) assert chain.first == prompt @@ -2204,7 +2204,7 @@ async def test_prompt_with_llm_parser( llm = FakeStreamingListLLM(responses=["bear, dog, cat", "tomato, lettuce, onion"]) parser = CommaSeparatedListOutputParser() - chain: Runnable = prompt | llm | parser + chain = prompt | llm | parser assert isinstance(chain, RunnableSequence) assert chain.first == prompt @@ -2517,7 +2517,7 @@ async def test_stream_log_lists() -> None: for i in range(4): yield AddableDict(alist=[str(i)]) - chain: Runnable = RunnableGenerator(list_producer) + chain = RunnableGenerator(list_producer) stream_log = [ part async for part in chain.astream_log({"question": "What is your name?"}) @@ -2697,7 +2697,7 @@ def test_combining_sequences( chain2 = cast("RunnableSequence", input_formatter | prompt2 | chat2 | parser2) - assert isinstance(chain, RunnableSequence) + assert isinstance(chain2, RunnableSequence) assert chain2.first == input_formatter assert chain2.middle == [prompt2, chat2] assert chain2.last == parser2 @@ -2705,6 +2705,7 @@ def test_combining_sequences( combined_chain = cast("RunnableSequence", chain | chain2) + assert isinstance(combined_chain, RunnableSequence) assert combined_chain.first == prompt assert combined_chain.middle == [ chat, @@ -2869,13 +2870,13 @@ def test_seq_prompt_dict(mocker: MockerFixture, snapshot: SnapshotAssertion) -> @freeze_time("2023-01-01") def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> None: - chain1: Runnable = ChatPromptTemplate.from_template( + chain1 = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) - chain2: Runnable = ChatPromptTemplate.from_template( + chain2 = ChatPromptTemplate.from_template( "You are an english major. Answer the question: {question}" ) | FakeListLLM(responses=["2"]) - router: Runnable = RouterRunnable({"math": chain1, "english": chain2}) + router = RouterRunnable({"math": chain1, "english": chain2}) chain: Runnable = { "key": lambda x: x["key"], "input": {"question": lambda x: x["question"]}, @@ -2913,13 +2914,13 @@ def test_router_runnable(mocker: MockerFixture, snapshot: SnapshotAssertion) -> async def test_router_runnable_async() -> None: - chain1: Runnable = ChatPromptTemplate.from_template( + chain1 = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) - chain2: Runnable = ChatPromptTemplate.from_template( + chain2 = ChatPromptTemplate.from_template( "You are an english major. Answer the question: {question}" ) | FakeListLLM(responses=["2"]) - router: Runnable = RouterRunnable({"math": chain1, "english": chain2}) + router = RouterRunnable({"math": chain1, "english": chain2}) chain: Runnable = { "key": lambda x: x["key"], "input": {"question": lambda x: x["question"]}, @@ -2941,13 +2942,13 @@ async def test_router_runnable_async() -> None: def test_higher_order_lambda_runnable( mocker: MockerFixture, snapshot: SnapshotAssertion ) -> None: - math_chain: Runnable = ChatPromptTemplate.from_template( + math_chain = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) - english_chain: Runnable = ChatPromptTemplate.from_template( + english_chain = ChatPromptTemplate.from_template( "You are an english major. Answer the question: {question}" ) | FakeListLLM(responses=["2"]) - input_map: Runnable = RunnableParallel( + input_map = RunnableParallel( key=lambda x: x["key"], input={"question": lambda x: x["question"]}, ) @@ -2997,13 +2998,13 @@ def test_higher_order_lambda_runnable( async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None: - math_chain: Runnable = ChatPromptTemplate.from_template( + math_chain = ChatPromptTemplate.from_template( "You are a math genius. Answer the question: {question}" ) | FakeListLLM(responses=["4"]) - english_chain: Runnable = ChatPromptTemplate.from_template( + english_chain = ChatPromptTemplate.from_template( "You are an english major. Answer the question: {question}" ) | FakeListLLM(responses=["2"]) - input_map: Runnable = RunnableParallel( + input_map = RunnableParallel( key=lambda x: x["key"], input={"question": lambda x: x["question"]}, ) @@ -3779,7 +3780,7 @@ async def test_deep_astream_assign() -> None: def test_runnable_sequence_transform() -> None: llm = FakeStreamingListLLM(responses=["foo-lish"]) - chain: Runnable = llm | StrOutputParser() + chain = llm | StrOutputParser() stream = chain.transform(llm.stream("Hi there!")) @@ -3792,7 +3793,7 @@ def test_runnable_sequence_transform() -> None: async def test_runnable_sequence_atransform() -> None: llm = FakeStreamingListLLM(responses=["foo-lish"]) - chain: Runnable = llm | StrOutputParser() + chain = llm | StrOutputParser() stream = chain.atransform(llm.astream("Hi there!")) @@ -3867,7 +3868,7 @@ def test_recursive_lambda() -> None: def test_retrying(mocker: MockerFixture) -> None: - def _lambda(x: int) -> Union[int, Runnable]: + def _lambda(x: int) -> int: if x == 1: msg = "x is 1" raise ValueError(msg) @@ -3932,7 +3933,7 @@ def test_retrying(mocker: MockerFixture) -> None: async def test_async_retrying(mocker: MockerFixture) -> None: - def _lambda(x: int) -> Union[int, Runnable]: + def _lambda(x: int) -> int: if x == 1: msg = "x is 1" raise ValueError(msg) @@ -4046,7 +4047,7 @@ async def test_runnable_lambda_astream() -> None: """Test that astream works for both normal functions & those returning Runnable.""" # Wrapper to make a normal function async - def awrapper(func: Callable) -> Callable[..., Awaitable[Any]]: + def awrapper(func: Callable[..., Any]) -> Callable[..., Awaitable[Any]]: async def afunc(*args: Any, **kwargs: Any) -> Any: return func(*args, **kwargs) @@ -4140,8 +4141,8 @@ def test_seq_batch_return_exceptions(mocker: MockerFixture) -> None: def _batch( self, inputs: list[str], - ) -> list: - outputs: list[Any] = [] + ) -> list[Union[str, Exception]]: + outputs: list[Union[str, Exception]] = [] for value in inputs: if value.startswith(self.fail_starts_with): outputs.append( @@ -4281,8 +4282,8 @@ async def test_seq_abatch_return_exceptions(mocker: MockerFixture) -> None: async def _abatch( self, inputs: list[str], - ) -> list: - outputs: list[Any] = [] + ) -> list[Union[str, Exception]]: + outputs: list[Union[str, Exception]] = [] for value in inputs: if value.startswith(self.fail_starts_with): outputs.append( @@ -5534,7 +5535,7 @@ def test_listeners() -> None: from langchain_core.runnables import RunnableLambda from langchain_core.tracers.schemas import Run - def fake_chain(inputs: dict) -> dict: + def fake_chain(inputs: dict[str, str]) -> dict[str, str]: return {**inputs, "key": "extra"} shared_state = {} @@ -5564,7 +5565,7 @@ async def test_listeners_async() -> None: from langchain_core.runnables import RunnableLambda from langchain_core.tracers.schemas import Run - def fake_chain(inputs: dict) -> dict: + def fake_chain(inputs: dict[str, str]) -> dict[str, str]: return {**inputs, "key": "extra"} shared_state = {} @@ -5577,7 +5578,7 @@ async def test_listeners_async() -> None: def on_end(run: Run) -> None: shared_state[run.id]["outputs"] = run.inputs - chain: Runnable = ( + chain = ( RunnableLambda(fake_chain) .with_listeners(on_end=on_end, on_start=on_start) .map() @@ -5647,7 +5648,7 @@ def test_pydantic_protected_namespaces() -> None: with warnings.catch_warnings(): warnings.simplefilter("error") - class CustomChatModel(RunnableSerializable): + class CustomChatModel(RunnableSerializable[str, str]): model_kwargs: dict[str, Any] = Field(default_factory=dict) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index b41409754a3..984126ca149 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -2,7 +2,7 @@ import asyncio import sys -from collections.abc import AsyncIterator, Sequence +from collections.abc import AsyncIterator, Mapping, Sequence from itertools import cycle from typing import Any, cast @@ -44,12 +44,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: return cast("list[StreamEvent]", [{**event, "run_id": ""} for event in events]) -async def _as_async_iterator(iterable: list) -> AsyncIterator: - """Converts an iterable into an async iterator.""" - for item in iterable: - yield item - - async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEvent]: """Collect the events and remove the run ids.""" materialized_events = [event async for event in events] @@ -59,7 +53,9 @@ async def _collect_events(events: AsyncIterator[StreamEvent]) -> list[StreamEven return events_ -def _assert_events_equal_allow_superset_metadata(events: list, expected: list) -> None: +def _assert_events_equal_allow_superset_metadata( + events: Sequence[Mapping[str, Any]], expected: Sequence[Mapping[str, Any]] +) -> None: """Assert that the events are equal.""" assert len(events) == len(expected) for i, (event, expected_event) in enumerate(zip(events, expected)): @@ -1910,7 +1906,7 @@ async def test_runnable_with_message_history() -> None: # Here we use a global variable to store the chat message history. # This will make it easier to inspect it to see the underlying results. - store: dict = {} + store: dict[str, list[BaseMessage]] = {} def get_by_session_id(session_id: str) -> BaseChatMessageHistory: """Get a chat message history.""" diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index a7731053032..72f183a5d66 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -74,12 +74,6 @@ def _with_nulled_run_id(events: Sequence[StreamEvent]) -> list[StreamEvent]: ) -async def _as_async_iterator(iterable: list) -> AsyncIterator: - """Converts an iterable into an async iterator.""" - for item in iterable: - yield item - - async def _collect_events( events: AsyncIterator[StreamEvent], *, with_nulled_ids: bool = True ) -> list[StreamEvent]: @@ -1866,7 +1860,7 @@ async def test_runnable_with_message_history() -> None: # Here we use a global variable to store the chat message history. # This will make it easier to inspect it to see the underlying results. - store: dict = {} + store: dict[str, list[BaseMessage]] = {} def get_by_session_id(session_id: str) -> BaseChatMessageHistory: """Get a chat message history.""" diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 08a846827f3..de555d4e4de 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -20,7 +20,7 @@ from langchain_core.runnables.base import RunnableLambda, RunnableParallel from langchain_core.tracers.langchain import LangChainTracer -def _get_posts(client: Client) -> list: +def _get_posts(client: Client) -> list[dict[str, Any]]: mock_calls = client.session.request.mock_calls # type: ignore[attr-defined] posts = [] for call in mock_calls: @@ -274,7 +274,7 @@ class TestRunnableSequenceParallelTraceNesting: def before(x: int) -> int: return x - def after(x: dict) -> int: + def after(x: dict[str, Any]) -> int: return x["chain_result"] sequence = before | parallel | after diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py index af2b603ec35..edd0abd7269 100644 --- a/libs/core/tests/unit_tests/runnables/test_utils.py +++ b/libs/core/tests/unit_tests/runnables/test_utils.py @@ -1,5 +1,5 @@ import sys -from typing import Callable +from typing import Any, Callable import pytest @@ -22,7 +22,7 @@ from langchain_core.runnables.utils import ( (lambda x: x if x > 0 else 0, "lambda x: x if x > 0 else 0"), # noqa: FURB136 ], ) -def test_get_lambda_source(func: Callable, expected_source: str) -> None: +def test_get_lambda_source(func: Callable[..., Any], expected_source: str) -> None: """Test get_lambda_source function.""" source = get_lambda_source(func) assert source == expected_source diff --git a/libs/core/tests/unit_tests/stores/test_in_memory.py b/libs/core/tests/unit_tests/stores/test_in_memory.py index cc8ec684fbe..6034f59cdab 100644 --- a/libs/core/tests/unit_tests/stores/test_in_memory.py +++ b/libs/core/tests/unit_tests/stores/test_in_memory.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from langchain_tests.integration_tests.base_store import ( BaseStoreAsyncTests, @@ -8,7 +10,7 @@ from langchain_core.stores import InMemoryStore # Check against standard tests -class TestSyncInMemoryStore(BaseStoreSyncTests): +class TestSyncInMemoryStore(BaseStoreSyncTests[Any]): @pytest.fixture def kv_store(self) -> InMemoryStore: return InMemoryStore() diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index f5cc42e1b4d..c2fbba2e8ef 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -39,7 +39,6 @@ from langchain_core.messages import ToolCall, ToolMessage from langchain_core.messages.tool import ToolOutputMixin from langchain_core.retrievers import BaseRetriever from langchain_core.runnables import ( - Runnable, RunnableConfig, RunnableLambda, ensure_config, @@ -72,7 +71,7 @@ from tests.unit_tests.fake.callbacks import FakeCallbackHandler from tests.unit_tests.pydantic_utils import _schema -def _get_tool_call_json_schema(tool: BaseTool) -> dict: +def _get_tool_call_json_schema(tool: BaseTool) -> dict[str, Any]: tool_schema = tool.tool_call_schema if isinstance(tool_schema, dict): return tool_schema @@ -1402,15 +1401,15 @@ class _MockStructuredToolWithRawOutput(BaseTool): self, arg1: int, arg2: bool, # noqa: FBT001 - arg3: Optional[dict] = None, - ) -> tuple[str, dict]: + arg3: Optional[dict[str, str]] = None, + ) -> tuple[str, dict[str, Any]]: return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @tool("structured_api", response_format="content_and_artifact") def _mock_structured_tool_with_artifact( - *, arg1: int, arg2: bool, arg3: Optional[dict] = None -) -> tuple[str, dict]: + *, arg1: int, arg2: bool, arg3: Optional[dict[str, str]] = None +) -> tuple[str, dict[str, Any]]: """A Structured Tool.""" return f"{arg1} {arg2}", {"arg1": arg1, "arg2": arg2, "arg3": arg3} @@ -1419,7 +1418,7 @@ def _mock_structured_tool_with_artifact( "tool", [_MockStructuredToolWithRawOutput(), _mock_structured_tool_with_artifact] ) def test_tool_call_input_tool_message_with_artifact(tool: BaseTool) -> None: - tool_call: dict = { + tool_call: dict[str, Any] = { "name": "structured_api", "args": {"arg1": 1, "arg2": True, "arg3": {"img": "base64string..."}}, "id": "123", @@ -1448,7 +1447,7 @@ def test_convert_from_runnable_dict() -> None: def f(x: Args) -> str: return str(x["a"] * max(x["b"])) - runnable: Runnable = RunnableLambda(f) + runnable = RunnableLambda(f) as_tool = runnable.as_tool() args_schema = as_tool.args_schema assert args_schema is not None @@ -1480,14 +1479,14 @@ def test_convert_from_runnable_dict() -> None: a: int = Field(..., description="Integer") b: list[int] = Field(..., description="List of ints") - runnable = RunnableLambda(g) - as_tool = runnable.as_tool(GSchema) - as_tool.invoke({"a": 3, "b": [1, 2]}) + runnable2 = RunnableLambda(g) + as_tool2 = runnable2.as_tool(GSchema) + as_tool2.invoke({"a": 3, "b": [1, 2]}) # Specify via arg_types: - runnable = RunnableLambda(g) - as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]}) - result = as_tool.invoke({"a": 3, "b": [1, 2]}) + runnable3 = RunnableLambda(g) + as_tool3 = runnable3.as_tool(arg_types={"a": int, "b": list[int]}) + result = as_tool3.invoke({"a": 3, "b": [1, 2]}) assert result == "6" # Test with config @@ -1496,9 +1495,9 @@ def test_convert_from_runnable_dict() -> None: assert config["configurable"]["foo"] == "not-bar" return str(x["a"] * max(x["b"])) - runnable = RunnableLambda(h) - as_tool = runnable.as_tool(arg_types={"a": int, "b": list[int]}) - result = as_tool.invoke( + runnable4 = RunnableLambda(h) + as_tool4 = runnable4.as_tool(arg_types={"a": int, "b": list[int]}) + result = as_tool4.invoke( {"a": 3, "b": [1, 2]}, config={"configurable": {"foo": "not-bar"}} ) assert result == "6" @@ -1512,7 +1511,7 @@ def test_convert_from_runnable_other() -> None: def g(x: str) -> str: return x + "z" - runnable: Runnable = RunnableLambda(f) | g + runnable = RunnableLambda(f) | g as_tool = runnable.as_tool() args_schema = as_tool.args_schema assert args_schema is None @@ -1527,10 +1526,10 @@ def test_convert_from_runnable_other() -> None: assert config["configurable"]["foo"] == "not-bar" return x + "a" - runnable = RunnableLambda(h) - as_tool = runnable.as_tool() - result = as_tool.invoke("b", config={"configurable": {"foo": "not-bar"}}) - assert result == "ba" + runnable2 = RunnableLambda(h) + as_tool2 = runnable2.as_tool() + result2 = as_tool2.invoke("b", config={"configurable": {"foo": "not-bar"}}) + assert result2 == "ba" @tool("foo", parse_docstring=True) @@ -1785,7 +1784,7 @@ def test_tool_inherited_injected_arg() -> None: } -def _get_parametrized_tools() -> list: +def _get_parametrized_tools() -> list[Callable[..., Any]]: def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str: """my_tool.""" return some_tool @@ -1800,7 +1799,7 @@ def _get_parametrized_tools() -> list: @pytest.mark.parametrize("tool_", _get_parametrized_tools()) -def test_fn_injected_arg_with_schema(tool_: Callable) -> None: +def test_fn_injected_arg_with_schema(tool_: Callable[..., Any]) -> None: assert convert_to_openai_function(tool_) == { "name": tool_.__name__, "description": "my_tool.", @@ -2528,13 +2527,13 @@ def test_tool_decorator_description() -> None: assert foo_args_jsons_schema.description == "JSON Schema." assert ( - cast("dict", foo_args_jsons_schema.tool_call_schema)["description"] + cast("dict[str, Any]", foo_args_jsons_schema.tool_call_schema)["description"] == "JSON Schema." ) assert foo_args_jsons_schema_with_description.description == "description" assert ( - cast("dict", foo_args_jsons_schema_with_description.tool_call_schema)[ + cast("dict[str, Any]", foo_args_jsons_schema_with_description.tool_call_schema)[ "description" ] == "description"