core: Add mypy strict-equality rule (#31286)

This commit is contained in:
Christophe Bornet 2025-06-02 20:24:35 +02:00 committed by GitHub
parent 2c4e0ab3bc
commit 539e5b6936
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 43 additions and 55 deletions

View File

@ -51,7 +51,6 @@ from langchain_core.messages import (
AIMessage,
AnyMessage,
BaseMessage,
BaseMessageChunk,
HumanMessage,
convert_to_messages,
convert_to_openai_image_block,
@ -446,13 +445,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> Iterator[BaseMessageChunk]:
) -> Iterator[BaseMessage]:
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
# model doesn't implement streaming, so use default implementation
yield cast(
"BaseMessageChunk",
self.invoke(input, config=config, stop=stop, **kwargs),
)
yield self.invoke(input, config=config, stop=stop, **kwargs)
else:
config = ensure_config(config)
messages = self._convert_input(input).to_messages()
@ -537,13 +533,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
*,
stop: Optional[list[str]] = None,
**kwargs: Any,
) -> AsyncIterator[BaseMessageChunk]:
) -> AsyncIterator[BaseMessage]:
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
# No async or sync stream is implemented, so fall back to ainvoke
yield cast(
"BaseMessageChunk",
await self.ainvoke(input, config=config, stop=stop, **kwargs),
)
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
return
config = ensure_config(config)
@ -1454,7 +1447,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
PydanticToolsParser,
)
if self.bind_tools is BaseChatModel.bind_tools:
if type(self).bind_tools is BaseChatModel.bind_tools:
msg = "with_structured_output is not implemented for this model."
raise NotImplementedError(msg)

View File

@ -4331,8 +4331,9 @@ class RunnableLambda(Runnable[Input, Output]):
self,
func: Union[
Union[
Callable[[Input], Output],
Callable[[Input], Iterator[Output]],
Callable[[Input], Runnable[Input, Output]],
Callable[[Input], Output],
Callable[[Input, RunnableConfig], Output],
Callable[[Input, CallbackManagerForChainRun], Output],
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],

View File

@ -77,6 +77,7 @@ if IS_PYDANTIC_V1:
TypeBaseModel = type[BaseModel]
elif IS_PYDANTIC_V2:
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
from pydantic.v1.fields import ModelField
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
@ -373,20 +374,20 @@ if IS_PYDANTIC_V2:
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
@overload
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
@overload
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
def get_fields(
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
"""Get the field names of a Pydantic model."""
if hasattr(model, "model_fields"):
return model.model_fields
if hasattr(model, "__fields__"):
return model.__fields__ # type: ignore[return-value]
return model.__fields__
msg = f"Expected a Pydantic model. Got {type(model)}"
raise TypeError(msg)

View File

@ -74,7 +74,6 @@ report_deprecated_as_note = "True"
# TODO: activate for 'strict' checking
disallow_any_generics = "False"
warn_return_any = "False"
strict_equality = "False"
[tool.ruff]

View File

@ -1,5 +1,7 @@
"""Tests for verifying that testing utility code works as expected."""
import operator
from functools import reduce
from itertools import cycle
from typing import Any, Optional, Union
from uuid import UUID
@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None:
]
assert len({chunk.id for chunk in chunks}) == 1
accumulate_chunks = None
for chunk in chunks:
if accumulate_chunks is None:
accumulate_chunks = chunk
else:
accumulate_chunks += chunk
accumulate_chunks = reduce(operator.add, chunks)
assert accumulate_chunks == AIMessageChunk(
content="",

View File

@ -1,7 +1,7 @@
"""Test PydanticOutputParser."""
from enum import Enum
from typing import Literal, Optional
from typing import Literal, Optional, Union
import pydantic
import pytest
@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel):
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
def test_pydantic_parser_chaining(
pydantic_object: TBaseModel,
pydantic_object: Union[type[ForecastV2], type[ForecastV1]],
) -> None:
prompt = PromptTemplate(
template="""{{
@ -43,11 +43,11 @@ def test_pydantic_parser_chaining(
model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
chain = prompt | model | parser
res = chain.invoke({})
assert type(res) is pydantic_object
assert isinstance(res, pydantic_object)
assert res.f_or_c == "C"
assert res.temperature == 20
assert res.forecast == "Sunny"

View File

@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None:
template = " {{''.__class__.__bases__[0] }} " # malicious code
prompt = PromptTemplate.from_template(template, template_format="jinja2")
with pytest.raises(jinja2.exceptions.SecurityError):
assert prompt.format() == []
prompt.format()
@pytest.mark.requires("jinja2")

View File

@ -51,7 +51,7 @@ def test_structured_prompt_pydantic() -> None:
chain = prompt | model
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) # type: ignore[comparison-overlap]
def test_structured_prompt_dict() -> None:
@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None:
chain = prompt | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
chain = loads(dumps(prompt)) | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
def test_structured_prompt_kwargs() -> None:
@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None:
)
model = FakeStructuredChatModel(responses=[])
chain = prompt | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
chain = loads(dumps(prompt)) | model
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
class OutputSchema(BaseModel):
name: str
@ -116,7 +116,7 @@ def test_structured_prompt_kwargs() -> None:
chain = prompt | model
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) # type: ignore[comparison-overlap]
def test_structured_prompt_template_format() -> None:

View File

@ -671,7 +671,7 @@ def test_with_types_with_type_generics() -> None:
def test_schema_with_itemgetter() -> None:
"""Test runnable with itemgetter."""
foo = RunnableLambda(itemgetter("hello"))
foo: Runnable = RunnableLambda(itemgetter("hello"))
assert _schema(foo.input_schema) == {
"properties": {"hello": {"title": "Hello"}},
"required": ["hello"],
@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None:
# sleep to better simulate a real stream
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
output = list(RunnableLambda(lambda _: llm).stream(""))
output = list(RunnableLambda[str, str](lambda _: llm).stream(""))
assert output == list(llm_res)
@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
config: RunnableConfig = {"callbacks": [tracer]}
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
llm_res
)
assert list(
RunnableLambda[str, str](lambda _: llm).stream("", config=config)
) == list(llm_res)
assert len(tracer.runs) == 1
assert tracer.runs[0].error is None
@ -4075,10 +4075,7 @@ async def test_runnable_lambda_astream() -> None:
assert output == list(llm_res)
output = [
chunk
async for chunk in cast(
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
)
chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("")
]
assert output == list(llm_res)
@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
config: RunnableConfig = {"callbacks": [tracer]}
assert [
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
_
async for _ in RunnableLambda[str, str](lambda _: llm).astream(
"", config=config
)
] == list(llm_res)
assert len(tracer.runs) == 1
@ -5300,7 +5300,7 @@ async def test_ainvoke_on_returned_runnable() -> None:
def func(_input: dict, /) -> Runnable:
return idchain
assert await RunnableLambda(func).ainvoke({})
assert await RunnableLambda[dict, bool](func).ainvoke({})
def test_invoke_stream_passthrough_assign_trace() -> None:

View File

@ -210,7 +210,7 @@ def test_decorator_with_specified_schema() -> None:
return f"{arg1} {arg2} {arg3}"
assert isinstance(tool_func_v1, BaseTool)
assert tool_func_v1.args_schema == _MockSchemaV1
assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)
def test_decorated_function_schema_equivalent() -> None:

View File

@ -15,7 +15,7 @@ from langchain_core.utils.aiter import abatch_iterate
],
)
async def test_abatch_iterate(
input_size: int, input_iterable: list[str], expected_output: list[str]
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
) -> None:
"""Test batching function."""

View File

@ -13,7 +13,7 @@ from langchain_core.utils.iter import batch_iterate
],
)
def test_batch_iterate(
input_size: int, input_iterable: list[str], expected_output: list[str]
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
) -> None:
"""Test batching function."""
assert list(batch_iterate(input_size, input_iterable)) == expected_output

View File

@ -1,5 +1,7 @@
"""Tests for verifying that testing utility code works as expected."""
import operator
from functools import reduce
from itertools import cycle
from typing import Any, Optional, Union
from uuid import UUID
@ -107,12 +109,7 @@ async def test_generic_fake_chat_model_stream() -> None:
),
]
accumulate_chunks = None
for chunk in chunks:
if accumulate_chunks is None:
accumulate_chunks = chunk
else:
accumulate_chunks += chunk
accumulate_chunks = reduce(operator.add, chunks)
assert accumulate_chunks == AIMessageChunk(
id="a1",