fix(core): disallow_any_generics (#38156)

Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
Christophe Bornet
2026-06-15 15:46:29 +02:00
committed by GitHub
parent b247e572b1
commit afff89a9f7
11 changed files with 36 additions and 27 deletions

View File

@@ -101,7 +101,7 @@ AnyMessage = Annotated[
"""A type representing any defined `Message` or `MessageChunk` type."""
def _has_base64_data(block: dict) -> bool:
def _has_base64_data(block: dict[str, Any]) -> bool:
"""Check if a content block contains base64 encoded data.
Args:
@@ -139,7 +139,7 @@ def _truncate(text: str, max_len: int = _XML_CONTENT_BLOCK_MAX_LEN) -> str:
return text[:max_len] + "..."
def _format_content_block_xml(block: dict) -> str | None:
def _format_content_block_xml(block: dict[str, Any]) -> str | None:
"""Format a content block as XML.
Args:
@@ -581,14 +581,18 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage:
MessageLikeRepresentation = (
BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
BaseMessage
| list[str]
| tuple[str, str | list[str | dict[str, Any]]]
| str
| dict[str, Any]
)
"""A type representing the various ways a message can be represented."""
def _create_message_from_message_type(
message_type: str,
content: str,
content: str | list[str | dict[str, Any]],
name: str | None = None,
tool_call_id: str | None = None,
tool_calls: list[dict[str, Any]] | None = None,
@@ -1534,7 +1538,7 @@ def convert_to_openai_messages(
@overload
def convert_to_openai_messages(
messages: _MultipleMessages,
messages: _MultipleMessages[Any],
*,
text_format: Literal["string", "block"] = "string",
include_id: bool = False,
@@ -1639,12 +1643,13 @@ def convert_to_openai_messages(
oai_messages: list[dict[str, Any]] = []
messages_: Sequence[MessageLikeRepresentation]
if is_single := isinstance(messages, (BaseMessage, dict, str)):
messages = [messages]
messages_ = [messages]
else:
messages_ = cast("Sequence[MessageLikeRepresentation]", messages)
messages = convert_to_messages(messages)
for i, message in enumerate(messages):
for i, message in enumerate(convert_to_messages(messages_)):
oai_msg: dict[str, Any] = {"role": _get_message_openai_role(message)}
tool_messages: list[dict[str, Any]] = []
content: str | list[dict[str, Any]]

View File

@@ -58,7 +58,10 @@ class BasePromptTemplate(
If not provided, all variables are assumed to be strings.
"""
output_parser: BaseOutputParser | None = None
# Ideally we would type output_parser as BaseOutputParser[Any]
# but that makes Pydantic fail (Pydantic tries to instantiate BaseOutputParser
# instead of using the provided output_parser...)
output_parser: BaseOutputParser | None = None # type: ignore[type-arg]
"""How to parse the output of calling an LLM on this formatted prompt."""
partial_variables: Mapping[str, Any] = Field(default_factory=dict)

View File

@@ -389,7 +389,9 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
```
"""
mapper: RunnableParallel
# Ideally we would type mapper as RunnableParallel[dict[str, Any]]
# but this fails validation for Pydantic <2.10
mapper: RunnableParallel # type: ignore[type-arg]
def __init__(self, mapper: RunnableParallel[dict[str, Any]], **kwargs: Any) -> None:
"""Create a `RunnableAssign`.

View File

@@ -92,9 +92,6 @@ strict = true
enable_error_code = "deprecated"
warn_unreachable = true
# TODO: activate for 'strict' checking
disallow_any_generics = false
[tool.ruff.format]
docstring-code-format = true

View File

@@ -1,5 +1,7 @@
"""Test functionality related to length based selector."""
from typing import Any
import pytest
from langchain_core.example_selectors import (
@@ -64,7 +66,7 @@ def test_selector_empty_example(
selector: LengthBasedExampleSelector,
) -> None:
"""Test Empty Example result empty."""
empty_list: list[dict] = []
empty_list: list[dict[str, Any]] = []
empty_selector = LengthBasedExampleSelector(
examples=empty_list,
example_prompt=selector.example_prompt,

View File

@@ -542,7 +542,7 @@ def test_content_blocks_v1_list_content_short_circuits() -> None:
returns it verbatim (the same object) without routing through the
translator. Covers both `AIMessage` and `AIMessageChunk`.
"""
content: list = [
content: list[str | dict[str, Any]] = [
{"type": "text", "text": "Hello"},
{"type": "tool_call", "name": "foo", "args": {"a": 1}, "id": "tc_1"},
]

View File

@@ -765,7 +765,7 @@ class FakeTokenCountingModel(FakeChatModel):
def test_convert_to_messages() -> None:
message_like: list = [
message_like: list[MessageLikeRepresentation] = [
# BaseMessage
SystemMessage("1"),
SystemMessage("1.1", additional_kwargs={"__openai_role__": "developer"}),

View File

@@ -460,7 +460,7 @@ def test_schemas(snapshot: SnapshotAssertion) -> None:
"title": "CommaSeparatedListOutputParserOutput",
}
router: Runnable = RouterRunnable({})
router = RouterRunnable[Any]({})
assert _schema(router.input_schema) == {
"$ref": "#/definitions/RouterInput",
@@ -709,7 +709,7 @@ def test_schema_complex_seq() -> None:
model = FakeListChatModel(responses=[""])
chain1: Runnable = RunnableSequence(
chain1 = RunnableSequence[dict[str, Any], str](
prompt1, model, StrOutputParser(), name="city_chain"
)
@@ -3024,7 +3024,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
input={"question": lambda x: x["question"]},
)
def router(value: dict[str, Any]) -> Runnable:
def router(value: dict[str, Any]) -> Runnable[dict[str, Any], str]:
if value["key"] == "math":
return itemgetter("input") | math_chain
if value["key"] == "english":
@@ -3046,7 +3046,7 @@ async def test_higher_order_lambda_runnable_async(mocker: MockerFixture) -> None
assert result2 == ["4", "2"]
# Test ainvoke
async def arouter(params: dict[str, Any]) -> Runnable:
async def arouter(params: dict[str, Any]) -> Runnable[dict[str, Any], str]:
if params["key"] == "math":
return itemgetter("input") | math_chain
if params["key"] == "english":
@@ -3925,10 +3925,10 @@ def test_each(snapshot: SnapshotAssertion) -> None:
def test_recursive_lambda() -> None:
def _simple_recursion(x: int) -> int | Runnable:
def _simple_recursion(x: int) -> Runnable[Any, int]:
if x < 10:
return RunnableLambda(lambda *_: _simple_recursion(x + 1))
return x
return RunnableLambda(lambda *_: x)
runnable = RunnableLambda(_simple_recursion)
assert runnable.invoke(5) == 10

View File

@@ -309,7 +309,7 @@ class TestRunnableSequenceParallelTraceNesting:
other_thing: Callable[
[int], Generator[int, None, None] | AsyncGenerator[int, None]
],
) -> RunnableLambda:
) -> RunnableLambda[int, int]:
@RunnableLambda
def my_child_function(a: int) -> int:
return a + 2
@@ -611,7 +611,7 @@ def test_traceable_parent_run_map_cleanup_with_sibling_children() -> None:
with tracing_context(client=tracer.client, enabled=True):
@traceable
def parent(x: dict) -> Any:
def parent(x: dict[str, Any]) -> Any:
return chain.invoke(x, config={"callbacks": [tracer]})
result = parent({"input": "hello"})

View File

@@ -3919,7 +3919,7 @@ def test_tool_invoke_returns_list_of_mixin() -> None:
"""End-to-end: a tool returning a list of ToolOutputMixin via invoke."""
@tool
def multi(x: int) -> list:
def multi(x: int) -> list[ToolMessage]:
"""Return multiple outputs."""
return [
ToolMessage(f"result-{i}", tool_call_id=f"sub-{i}", name="multi")

View File

@@ -440,7 +440,7 @@ def test_generation_chunk_addition_type_error() -> None:
],
)
def test_merge_lists(
left: list | None, right: list | None, expected: list | None
left: list[Any] | None, right: list[Any] | None, expected: list[Any] | None
) -> None:
left_copy = deepcopy(left)
right_copy = deepcopy(right)