diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index cc9cfa93a70..2b1a1ba7b6d 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -101,7 +101,7 @@ AnyMessage = Annotated[ """A type representing any defined `Message` or `MessageChunk` type.""" -def _has_base64_data(block: dict) -> bool: +def _has_base64_data(block: dict[str, Any]) -> bool: """Check if a content block contains base64 encoded data. Args: @@ -139,7 +139,7 @@ def _truncate(text: str, max_len: int = _XML_CONTENT_BLOCK_MAX_LEN) -> str: return text[:max_len] + "..." -def _format_content_block_xml(block: dict) -> str | None: +def _format_content_block_xml(block: dict[str, Any]) -> str | None: """Format a content block as XML. Args: @@ -581,14 +581,18 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage: MessageLikeRepresentation = ( - BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any] + BaseMessage + | list[str] + | tuple[str, str | list[str | dict[str, Any]]] + | str + | dict[str, Any] ) """A type representing the various ways a message can be represented.""" def _create_message_from_message_type( message_type: str, - content: str, + content: str | list[str | dict[str, Any]], name: str | None = None, tool_call_id: str | None = None, tool_calls: list[dict[str, Any]] | None = None, @@ -1534,7 +1538,7 @@ def convert_to_openai_messages( @overload def convert_to_openai_messages( - messages: _MultipleMessages, + messages: _MultipleMessages[Any], *, text_format: Literal["string", "block"] = "string", include_id: bool = False, @@ -1639,12 +1643,13 @@ def convert_to_openai_messages( oai_messages: list[dict[str, Any]] = [] + messages_: Sequence[MessageLikeRepresentation] if is_single := isinstance(messages, (BaseMessage, dict, str)): - messages = [messages] + messages_ = [messages] + else: + messages_ = cast("Sequence[MessageLikeRepresentation]", messages) - messages = convert_to_messages(messages) - - for i, message in enumerate(messages): + for i, message in enumerate(convert_to_messages(messages_)): oai_msg: dict[str, Any] = {"role": _get_message_openai_role(message)} tool_messages: list[dict[str, Any]] = [] content: str | list[dict[str, Any]] diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index e90220b796b..1386635aa3d 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -58,7 +58,10 @@ class BasePromptTemplate( If not provided, all variables are assumed to be strings. """ - output_parser: BaseOutputParser | None = None + # Ideally we would type output_parser as BaseOutputParser[Any] + # but that makes Pydantic fail (Pydantic tries to instantiate BaseOutputParser + # instead of using the provided output_parser...) + output_parser: BaseOutputParser | None = None # type: ignore[type-arg] """How to parse the output of calling an LLM on this formatted prompt.""" partial_variables: Mapping[str, Any] = Field(default_factory=dict) diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index c52731ceb72..7c8ca69a112 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -389,7 +389,9 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): ``` """ - mapper: RunnableParallel + # Ideally we would type mapper as RunnableParallel[dict[str, Any]] + # but this fails validation for Pydantic <2.10 + mapper: RunnableParallel # type: ignore[type-arg] def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None: """Create a `RunnableAssign`. diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 63b11114c2b..7dc4b502957 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -92,9 +92,6 @@ strict = true enable_error_code = "deprecated" warn_unreachable = true -# TODO: activate for 'strict' checking -disallow_any_generics = false - [tool.ruff.format] docstring-code-format = true diff --git a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py index f3bfe3b5a91..c29dc2b1db5 100644 --- a/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py +++ b/libs/core/tests/unit_tests/example_selectors/test_length_based_example_selector.py @@ -1,5 +1,7 @@ """Test functionality related to length based selector.""" +from typing import Any + import pytest from langchain_core.example_selectors import ( @@ -64,7 +66,7 @@ def test_selector_empty_example( selector: LengthBasedExampleSelector, ) -> None: """Test Empty Example result empty.""" - empty_list: list[dict] = [] + empty_list: list[dict[str, Any]] = [] empty_selector = LengthBasedExampleSelector( examples=empty_list, example_prompt=selector.example_prompt, diff --git a/libs/core/tests/unit_tests/messages/test_ai.py b/libs/core/tests/unit_tests/messages/test_ai.py index 8e410165f9e..e7d983d6ed4 100644 --- a/libs/core/tests/unit_tests/messages/test_ai.py +++ b/libs/core/tests/unit_tests/messages/test_ai.py @@ -542,7 +542,7 @@ def test_content_blocks_v1_list_content_short_circuits() -> None: returns it verbatim (the same object) without routing through the translator. Covers both `AIMessage` and `AIMessageChunk`. """ - content: list = [ + content: list[str | dict[str, Any]] = [ {"type": "text", "text": "Hello"}, {"type": "tool_call", "name": "foo", "args": {"a": 1}, "id": "tc_1"}, ] diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 07846eb1c20..6a930290958 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -765,7 +765,7 @@ class FakeTokenCountingModel(FakeChatModel): def test_convert_to_messages() -> None: - message_like: list = [ + message_like: list[MessageLikeRepresentation] = [ # BaseMessage SystemMessage("1"), SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}), diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 580b89bc9bb..24b986a2ebf 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -460,7 +460,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None: "title": "CommaSeparatedListOutputParserOutput", } - router: Runnable = RouterRunnable({}) + router = RouterRunnable[Any]({}) assert _schema(router.input_schema) == { "$ref": "#/definitions/RouterInput", @@ -709,7 +709,7 @@ def test_schema_complex_seq() -> None: model = FakeListChatModel(responses=[""]) - chain1: Runnable = RunnableSequence( + chain1 = RunnableSequence[dict[str, Any], str]( prompt1, model, StrOutputParser(), name="city_chain" ) @@ -3024,7 +3024,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None input={"question": lambda x: x["question"]}, ) - def router(value: dict[str, Any]) -> Runnable: + def router(value: dict[str, Any]) -> Runnable[dict[str, Any], str]: if value["key"] == "math": return itemgetter("input") | math_chain if value["key"] == "english": @@ -3046,7 +3046,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None assert result2 == ["4", "2"] # Test ainvoke - async def arouter(params: dict[str, Any]) -> Runnable: + async def arouter(params: dict[str, Any]) -> Runnable[dict[str, Any], str]: if params["key"] == "math": return itemgetter("input") | math_chain if params["key"] == "english": @@ -3925,10 +3925,10 @@ def test_each(snapshot: SnapshotAssertion) -> None: def test_recursive_lambda() -> None: - def _simple_recursion(x: int) -> int | Runnable: + def _simple_recursion(x: int) -> Runnable[Any, int]: if x < 10: return RunnableLambda(lambda *_: _simple_recursion(x + 1)) - return x + return RunnableLambda(lambda *_: x) runnable = RunnableLambda(_simple_recursion) assert runnable.invoke(5) == 10 diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index c1f11d9c2a6..3716ebd0a20 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -309,7 +309,7 @@ class TestRunnableSequenceParallelTraceNesting: other_thing: Callable[ [int], Generator[int, None, None] | AsyncGenerator[int, None] ], - ) -> RunnableLambda: + ) -> RunnableLambda[int, int]: @RunnableLambda def my_child_function(a: int) -> int: return a + 2 @@ -611,7 +611,7 @@ def test_traceable_parent_run_map_cleanup_with_sibling_children() -> None: with tracing_context(client=tracer.client, enabled=True): @traceable - def parent(x: dict) -> Any: + def parent(x: dict[str, Any]) -> Any: return chain.invoke(x, config={"callbacks": [tracer]}) result = parent({"input": "hello"}) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 991a2f6ffb3..140da9cdd74 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -3919,7 +3919,7 @@ def test_tool_invoke_returns_list_of_mixin() -> None: """End-to-end: a tool returning a list of ToolOutputMixin via invoke.""" @tool - def multi(x: int) -> list: + def multi(x: int) -> list[ToolMessage]: """Return multiple outputs.""" return [ ToolMessage(f"result-{i}", tool_call_id=f"sub-{i}", name="multi") diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 7799c9c9df0..95c7cf50a05 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -440,7 +440,7 @@ def test_generation_chunk_addition_type_error() -> None: ], ) def test_merge_lists( - left: list | None, right: list | None, expected: list | None + left: list[Any] | None, right: list[Any] | None, expected: list[Any] | None ) -> None: left_copy = deepcopy(left) right_copy = deepcopy(right)