core: Add mypy strict-equality rule (#31286)

Christophe Bornet, 2025-06-02 20:24:35 +02:00 (committed by GitHub)
parent 2c4e0ab3bc
commit 539e5b6936
13 changed files with 43 additions and 55 deletions
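
For context: mypy's `strict_equality` flag reports `==`, `is`, and container membership checks between types that can never overlap, surfacing them as `comparison-overlap` errors. A minimal sketch of the kind of code the newly enabled rule flags (names below are illustrative, not from this commit):

```python
def get_name() -> str:
    return "x"

# With --strict-equality, mypy reports:
# error: Non-overlapping equality check
#        (left operand type: "str", right operand type: "int")  [comparison-overlap]
if get_name() == 3:
    print("unreachable: a str can never equal an int")
```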

View File

@@ -51,7 +51,6 @@ from langchain_core.messages import (
     AIMessage,
     AnyMessage,
     BaseMessage,
-    BaseMessageChunk,
     HumanMessage,
     convert_to_messages,
     convert_to_openai_image_block,
@@ -446,13 +445,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> Iterator[BaseMessageChunk]:
+    ) -> Iterator[BaseMessage]:
         if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
             # model doesn't implement streaming, so use default implementation
-            yield cast(
-                "BaseMessageChunk",
-                self.invoke(input, config=config, stop=stop, **kwargs),
-            )
+            yield self.invoke(input, config=config, stop=stop, **kwargs)
         else:
             config = ensure_config(config)
             messages = self._convert_input(input).to_messages()
@@ -537,13 +533,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
         *,
         stop: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> AsyncIterator[BaseMessageChunk]:
+    ) -> AsyncIterator[BaseMessage]:
         if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
             # No async or sync stream is implemented, so fall back to ainvoke
-            yield cast(
-                "BaseMessageChunk",
-                await self.ainvoke(input, config=config, stop=stop, **kwargs),
-            )
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
             return
         config = ensure_config(config)
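
Why dropping the casts is sound (a sketch, not the actual langchain-core code): `Iterator` is covariant in its element type, so widening the declared element from `BaseMessageChunk` to `BaseMessage` lets the non-streaming fallback yield the `invoke`/`ainvoke` result directly:

```python
from collections.abc import Iterator

class Base: ...
class Chunk(Base): ...

def produce() -> Base:
    return Chunk()  # runtime value is a Chunk, declared type is Base

def stream_narrow() -> Iterator[Chunk]:
    yield produce()  # mypy error: "Base" is not a "Chunk"; a cast was needed

def stream_wide() -> Iterator[Base]:
    yield produce()  # OK: widening the element type removes the cast
```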
@@ -1454,7 +1447,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
             PydanticToolsParser,
         )

-        if self.bind_tools is BaseChatModel.bind_tools:
+        if type(self).bind_tools is BaseChatModel.bind_tools:
             msg = "with_structured_output is not implemented for this model."
             raise NotImplementedError(msg)
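
The `type(self)` lookup matters because attribute access on an instance creates a fresh bound-method object, so the old identity check could never be true. A standalone sketch of the difference (illustrative names):

```python
class Base:
    def method(self) -> None: ...

class Child(Base):
    pass  # does not override method

c = Child()
# A bound method object is created per attribute access, so identity
# against the plain function on the base class is always False:
assert c.method is not Base.method
# Looking the attribute up on the class returns the underlying function,
# so identity correctly answers "was this method overridden?":
assert type(c).method is Base.method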

View File

@@ -4331,8 +4331,9 @@ class RunnableLambda(Runnable[Input, Output]):
         self,
         func: Union[
             Union[
-                Callable[[Input], Output],
                 Callable[[Input], Iterator[Output]],
+                Callable[[Input], Runnable[Input, Output]],
+                Callable[[Input], Output],
                 Callable[[Input, RunnableConfig], Output],
                 Callable[[Input, CallbackManagerForChainRun], Output],
                 Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],

View File

@@ -77,6 +77,7 @@ if IS_PYDANTIC_V1:
     TypeBaseModel = type[BaseModel]
 elif IS_PYDANTIC_V2:
     from pydantic.v1.fields import FieldInfo as FieldInfoV1  # type: ignore[assignment]
+    from pydantic.v1.fields import ModelField

     # Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
     PydanticBaseModel = Union[BaseModel, pydantic.BaseModel]  # type: ignore[assignment,misc]
@@ -373,20 +374,20 @@ if IS_PYDANTIC_V2:
     def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...

     @overload
-    def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
+    def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...

     @overload
-    def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
+    def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...

     def get_fields(
         model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
-    ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
+    ) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
         """Get the field names of a Pydantic model."""
         if hasattr(model, "model_fields"):
             return model.model_fields
         if hasattr(model, "__fields__"):
-            return model.__fields__  # type: ignore[return-value]
+            return model.__fields__
         msg = f"Expected a Pydantic model. Got {type(model)}"
         raise TypeError(msg)
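
The change keeps the `@overload` pattern but returns `ModelField` for Pydantic v1 inputs, since `pydantic.v1` models store `ModelField` objects in `__fields__` rather than `FieldInfo`. A generic sketch of the dispatch pattern used here (the model classes are placeholders):

```python
from typing import Union, overload

class V2Model:
    model_fields: dict[str, str] = {}

class V1Model:
    __fields__: dict[str, int] = {}

@overload
def get_fields(model: V2Model) -> dict[str, str]: ...
@overload
def get_fields(model: V1Model) -> dict[str, int]: ...
def get_fields(model: Union[V2Model, V1Model]) -> Union[dict[str, str], dict[str, int]]:
    # One runtime implementation; the overloads give each caller the
    # precise return type for the argument it passed.
    if isinstance(model, V2Model):
        return model.model_fields
    return model.__fields__
```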

View File

@@ -74,7 +74,6 @@ report_deprecated_as_note = "True"
 # TODO: activate for 'strict' checking
 disallow_any_generics = "False"
 warn_return_any = "False"
-strict_equality = "False"

 [tool.ruff]

View File

@@ -1,5 +1,7 @@
 """Tests for verifying that testing utility code works as expected."""

+import operator
+from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None:
     ]
     assert len({chunk.id for chunk in chunks}) == 1

-    accumulate_chunks = None
-    for chunk in chunks:
-        if accumulate_chunks is None:
-            accumulate_chunks = chunk
-        else:
-            accumulate_chunks += chunk
+    accumulate_chunks = reduce(operator.add, chunks)

     assert accumulate_chunks == AIMessageChunk(
         content="",

View File

@@ -1,7 +1,7 @@
 """Test PydanticOutputParser."""

 from enum import Enum
-from typing import Literal, Optional
+from typing import Literal, Optional, Union

 import pydantic
 import pytest
@@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel):

 @pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
 def test_pydantic_parser_chaining(
-    pydantic_object: TBaseModel,
+    pydantic_object: Union[type[ForecastV2], type[ForecastV1]],
 ) -> None:
     prompt = PromptTemplate(
         template="""{{
@@ -43,11 +43,11 @@ def test_pydantic_parser_chaining(
     model = ParrotFakeChatModel()

-    parser = PydanticOutputParser(pydantic_object=pydantic_object)  # type: ignore[arg-type,var-annotated]
+    parser = PydanticOutputParser(pydantic_object=pydantic_object)  # type: ignore[type-var]
     chain = prompt | model | parser

     res = chain.invoke({})
-    assert type(res) is pydantic_object
+    assert isinstance(res, pydantic_object)
     assert res.f_or_c == "C"
     assert res.temperature == 20
     assert res.forecast == "Sunny"
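
`isinstance` is also what lets the following attribute accesses type-check: mypy narrows `res` after the assert, whereas `type(res) is pydantic_object` is an identity check between types that strict-equality treats as non-overlapping. A minimal sketch of the narrowing (hypothetical class):

```python
class Forecast:
    temperature: int = 20

def check(res: object, cls: type[Forecast]) -> None:
    # The isinstance assert acts as a type guard: mypy narrows `res`
    # to Forecast, so the attribute access below needs no cast.
    assert isinstance(res, cls)
    assert res.temperature == 20
```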

View File

@@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None:
     template = " {{''.__class__.__bases__[0] }} "  # malicious code
     prompt = PromptTemplate.from_template(template, template_format="jinja2")
     with pytest.raises(jinja2.exceptions.SecurityError):
-        assert prompt.format() == []
+        prompt.format()


 @pytest.mark.requires("jinja2")
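
The old assert compared a return value that can never exist: if `format()` raises as expected, the comparison is skipped, and if it doesn't, `pytest.raises` fails the test anyway. Strict-equality also flags the `str == list` comparison as non-overlapping. The bare call is the correct pattern, sketched here with a stand-in exception:

```python
import pytest

def render() -> str:
    raise ValueError("sandbox violation")  # stand-in for the SecurityError

# The call by itself is the assertion; comparing its (never produced)
# result would be dead code.
with pytest.raises(ValueError):
    render()
```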

View File

@@ -51,7 +51,7 @@ def test_structured_prompt_pydantic() -> None:
     chain = prompt | model

-    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)
+    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)  # type: ignore[comparison-overlap]


 def test_structured_prompt_dict() -> None:
@@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None:
     chain = prompt | model

-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}  # type: ignore[comparison-overlap]

     assert loads(dumps(prompt)).model_dump() == prompt.model_dump()

     chain = loads(dumps(prompt)) | model

-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}  # type: ignore[comparison-overlap]


 def test_structured_prompt_kwargs() -> None:
@@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None:
     )
     model = FakeStructuredChatModel(responses=[])
     chain = prompt | model
-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}  # type: ignore[comparison-overlap]
     assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
     chain = loads(dumps(prompt)) | model
-    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
+    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}  # type: ignore[comparison-overlap]

     class OutputSchema(BaseModel):
         name: str
@@ -116,7 +116,7 @@ def test_structured_prompt_kwargs() -> None:
     chain = prompt | model

-    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
+    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)  # type: ignore[comparison-overlap]


 def test_structured_prompt_template_format() -> None:
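
These asserts intentionally compare a value whose declared type (the fake chain's output) cannot overlap with a dict or `OutputSchema`, even though the runtime value matches; the targeted ignore silences only that error code. A condensed sketch of the situation (illustrative types):

```python
from typing import Any

def invoke() -> str:  # declared type, analogous to the chain's output type
    out: Any = {"name": 1, "value": 42}  # runtime value mypy can't see through
    return out

# Declared str can never equal a dict, so --strict-equality flags the
# comparison even though it is True at runtime; a narrow ignore keeps
# the intentional check while silencing only this error code:
assert invoke() == {"name": 1, "value": 42}  # type: ignore[comparison-overlap]
```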

View File

@@ -671,7 +671,7 @@ def test_with_types_with_type_generics() -> None:

 def test_schema_with_itemgetter() -> None:
     """Test runnable with itemgetter."""
-    foo = RunnableLambda(itemgetter("hello"))
+    foo: Runnable = RunnableLambda(itemgetter("hello"))
     assert _schema(foo.input_schema) == {
         "properties": {"hello": {"title": "Hello"}},
         "required": ["hello"],
@@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None:
     # sleep to better simulate a real stream
     llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)

-    output = list(RunnableLambda(lambda _: llm).stream(""))
+    output = list(RunnableLambda[str, str](lambda _: llm).stream(""))
     assert output == list(llm_res)
@@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
     llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
     config: RunnableConfig = {"callbacks": [tracer]}

-    assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
-        llm_res
-    )
+    assert list(
+        RunnableLambda[str, str](lambda _: llm).stream("", config=config)
+    ) == list(llm_res)

     assert len(tracer.runs) == 1
     assert tracer.runs[0].error is None
@@ -4075,10 +4075,7 @@ async def test_runnable_lambda_astream() -> None:
     assert output == list(llm_res)

     output = [
-        chunk
-        async for chunk in cast(
-            "AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
-        )
+        chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("")
     ]
     assert output == list(llm_res)
@@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
     config: RunnableConfig = {"callbacks": [tracer]}

     assert [
-        _ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
+        _
+        async for _ in RunnableLambda[str, str](lambda _: llm).astream(
+            "", config=config
+        )
     ] == list(llm_res)

     assert len(tracer.runs) == 1
@@ -5300,7 +5300,7 @@ async def test_ainvoke_on_returned_runnable() -> None:
     def func(_input: dict, /) -> Runnable:
         return idchain

-    assert await RunnableLambda(func).ainvoke({})
+    assert await RunnableLambda[dict, bool](func).ainvoke({})


 def test_invoke_stream_passthrough_assign_trace() -> None:
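
The explicit `RunnableLambda[str, str](...)` parameterization pins `Input`/`Output` values that mypy cannot infer from a bare `lambda _: llm` (the lambda returns a runnable, not the streamed `str` chunks). A self-contained sketch of parameterizing a generic class at the call site (illustrative `Lam` class, not the real `RunnableLambda`):

```python
from typing import Callable, Generic, TypeVar, Union

In = TypeVar("In")
Out = TypeVar("Out")

class Lam(Generic[In, Out]):
    def __init__(
        self,
        fn: Union[Callable[[In], Out], Callable[[In], "Lam[In, Out]"]],
    ) -> None:
        self.fn = fn

inner: Lam[str, str] = Lam(lambda s: s)
# From `lambda _: inner` alone, mypy would infer Out as Lam[str, str];
# subscripting the class at the call site pins the intended types:
outer = Lam[str, str](lambda _: inner)
```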

View File

@@ -210,7 +210,7 @@ def test_decorator_with_specified_schema() -> None:
         return f"{arg1} {arg2} {arg3}"

     assert isinstance(tool_func_v1, BaseTool)
-    assert tool_func_v1.args_schema == _MockSchemaV1
+    assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)


 def test_decorated_function_schema_equivalent() -> None:
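
Here the cast widens one side of the `==` so the declared types overlap; it is the alternative to the `comparison-overlap` ignores above when a suitable declared type exists. A small sketch of the idea (hypothetical `Schema` alias, not the real `ArgsSchema`):

```python
from typing import Union, cast

Schema = Union[type, dict]

class MockSchema:
    pass

declared: Schema = MockSchema
# Casting the class to the declared union makes both operands overlap,
# satisfying --strict-equality without an ignore comment:
assert declared == cast(Schema, MockSchema)
```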

View File

@@ -15,7 +15,7 @@ from langchain_core.utils.aiter import abatch_iterate
     ],
 )
 async def test_abatch_iterate(
-    input_size: int, input_iterable: list[str], expected_output: list[str]
+    input_size: int, input_iterable: list[str], expected_output: list[list[str]]
 ) -> None:
     """Test batching function."""

View File

@@ -13,7 +13,7 @@ from langchain_core.utils.iter import batch_iterate
     ],
 )
 def test_batch_iterate(
-    input_size: int, input_iterable: list[str], expected_output: list[str]
+    input_size: int, input_iterable: list[str], expected_output: list[list[str]]
 ) -> None:
     """Test batching function."""
     assert list(batch_iterate(input_size, input_iterable)) == expected_output
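
The annotation fix reflects that `batch_iterate` yields lists of up to `input_size` items, so the expected output is a list of lists rather than a flat list. A quick usage sketch (assuming the size-first signature shown in the test):

```python
from langchain_core.utils.iter import batch_iterate

# Each yielded batch is itself a list, hence expected_output: list[list[str]]
assert list(batch_iterate(2, ["a", "b", "c"])) == [["a", "b"], ["c"]]
```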

View File

@@ -1,5 +1,7 @@
 """Tests for verifying that testing utility code works as expected."""

+import operator
+from functools import reduce
 from itertools import cycle
 from typing import Any, Optional, Union
 from uuid import UUID
@@ -107,12 +109,7 @@ async def test_generic_fake_chat_model_stream() -> None:
         ),
     ]

-    accumulate_chunks = None
-    for chunk in chunks:
-        if accumulate_chunks is None:
-            accumulate_chunks = chunk
-        else:
-            accumulate_chunks += chunk
+    accumulate_chunks = reduce(operator.add, chunks)

     assert accumulate_chunks == AIMessageChunk(
         id="a1",