mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 09:04:03 +00:00
core: Add mypy strict-equality rule (#31286)
This commit is contained in:
parent
2c4e0ab3bc
commit
539e5b6936
@ -51,7 +51,6 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AnyMessage,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
HumanMessage,
|
||||
convert_to_messages,
|
||||
convert_to_openai_image_block,
|
||||
@ -446,13 +445,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[BaseMessageChunk]:
|
||||
) -> Iterator[BaseMessage]:
|
||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||
# model doesn't implement streaming, so use default implementation
|
||||
yield cast(
|
||||
"BaseMessageChunk",
|
||||
self.invoke(input, config=config, stop=stop, **kwargs),
|
||||
)
|
||||
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||
else:
|
||||
config = ensure_config(config)
|
||||
messages = self._convert_input(input).to_messages()
|
||||
@ -537,13 +533,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
*,
|
||||
stop: Optional[list[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[BaseMessageChunk]:
|
||||
) -> AsyncIterator[BaseMessage]:
|
||||
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||
# No async or sync stream is implemented, so fall back to ainvoke
|
||||
yield cast(
|
||||
"BaseMessageChunk",
|
||||
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
||||
)
|
||||
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
||||
return
|
||||
|
||||
config = ensure_config(config)
|
||||
@ -1454,7 +1447,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
PydanticToolsParser,
|
||||
)
|
||||
|
||||
if self.bind_tools is BaseChatModel.bind_tools:
|
||||
if type(self).bind_tools is BaseChatModel.bind_tools:
|
||||
msg = "with_structured_output is not implemented for this model."
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
@ -4331,8 +4331,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
||||
self,
|
||||
func: Union[
|
||||
Union[
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input], Iterator[Output]],
|
||||
Callable[[Input], Runnable[Input, Output]],
|
||||
Callable[[Input], Output],
|
||||
Callable[[Input, RunnableConfig], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||
|
@ -77,6 +77,7 @@ if IS_PYDANTIC_V1:
|
||||
TypeBaseModel = type[BaseModel]
|
||||
elif IS_PYDANTIC_V2:
|
||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||
from pydantic.v1.fields import ModelField
|
||||
|
||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||
@ -373,20 +374,20 @@ if IS_PYDANTIC_V2:
|
||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
|
||||
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
||||
|
||||
@overload
|
||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
||||
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
||||
|
||||
def get_fields(
|
||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
||||
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
||||
"""Get the field names of a Pydantic model."""
|
||||
if hasattr(model, "model_fields"):
|
||||
return model.model_fields
|
||||
|
||||
if hasattr(model, "__fields__"):
|
||||
return model.__fields__ # type: ignore[return-value]
|
||||
return model.__fields__
|
||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||
raise TypeError(msg)
|
||||
|
||||
|
@ -74,7 +74,6 @@ report_deprecated_as_note = "True"
|
||||
# TODO: activate for 'strict' checking
|
||||
disallow_any_generics = "False"
|
||||
warn_return_any = "False"
|
||||
strict_equality = "False"
|
||||
|
||||
|
||||
[tool.ruff]
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Tests for verifying that testing utility code works as expected."""
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from itertools import cycle
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
]
|
||||
assert len({chunk.id for chunk in chunks}) == 1
|
||||
|
||||
accumulate_chunks = None
|
||||
for chunk in chunks:
|
||||
if accumulate_chunks is None:
|
||||
accumulate_chunks = chunk
|
||||
else:
|
||||
accumulate_chunks += chunk
|
||||
accumulate_chunks = reduce(operator.add, chunks)
|
||||
|
||||
assert accumulate_chunks == AIMessageChunk(
|
||||
content="",
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Test PydanticOutputParser."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import pydantic
|
||||
import pytest
|
||||
@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel):
|
||||
|
||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||
def test_pydantic_parser_chaining(
|
||||
pydantic_object: TBaseModel,
|
||||
pydantic_object: Union[type[ForecastV2], type[ForecastV1]],
|
||||
) -> None:
|
||||
prompt = PromptTemplate(
|
||||
template="""{{
|
||||
@ -43,11 +43,11 @@ def test_pydantic_parser_chaining(
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
|
||||
chain = prompt | model | parser
|
||||
|
||||
res = chain.invoke({})
|
||||
assert type(res) is pydantic_object
|
||||
assert isinstance(res, pydantic_object)
|
||||
assert res.f_or_c == "C"
|
||||
assert res.temperature == 20
|
||||
assert res.forecast == "Sunny"
|
||||
|
@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None:
|
||||
template = " {{''.__class__.__bases__[0] }} " # malicious code
|
||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||
with pytest.raises(jinja2.exceptions.SecurityError):
|
||||
assert prompt.format() == []
|
||||
prompt.format()
|
||||
|
||||
|
||||
@pytest.mark.requires("jinja2")
|
||||
|
@ -51,7 +51,7 @@ def test_structured_prompt_pydantic() -> None:
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)
|
||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
def test_structured_prompt_dict() -> None:
|
||||
@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None:
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
|
||||
|
||||
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
||||
|
||||
chain = loads(dumps(prompt)) | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
def test_structured_prompt_kwargs() -> None:
|
||||
@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None:
|
||||
)
|
||||
model = FakeStructuredChatModel(responses=[])
|
||||
chain = prompt | model
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
|
||||
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
||||
chain = loads(dumps(prompt)) | model
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
|
||||
|
||||
class OutputSchema(BaseModel):
|
||||
name: str
|
||||
@ -116,7 +116,7 @@ def test_structured_prompt_kwargs() -> None:
|
||||
|
||||
chain = prompt | model
|
||||
|
||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
|
||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) # type: ignore[comparison-overlap]
|
||||
|
||||
|
||||
def test_structured_prompt_template_format() -> None:
|
||||
|
@ -671,7 +671,7 @@ def test_with_types_with_type_generics() -> None:
|
||||
|
||||
def test_schema_with_itemgetter() -> None:
|
||||
"""Test runnable with itemgetter."""
|
||||
foo = RunnableLambda(itemgetter("hello"))
|
||||
foo: Runnable = RunnableLambda(itemgetter("hello"))
|
||||
assert _schema(foo.input_schema) == {
|
||||
"properties": {"hello": {"title": "Hello"}},
|
||||
"required": ["hello"],
|
||||
@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None:
|
||||
# sleep to better simulate a real stream
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
|
||||
output = list(RunnableLambda(lambda _: llm).stream(""))
|
||||
output = list(RunnableLambda[str, str](lambda _: llm).stream(""))
|
||||
assert output == list(llm_res)
|
||||
|
||||
|
||||
@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
|
||||
llm_res
|
||||
)
|
||||
assert list(
|
||||
RunnableLambda[str, str](lambda _: llm).stream("", config=config)
|
||||
) == list(llm_res)
|
||||
|
||||
assert len(tracer.runs) == 1
|
||||
assert tracer.runs[0].error is None
|
||||
@ -4075,10 +4075,7 @@ async def test_runnable_lambda_astream() -> None:
|
||||
assert output == list(llm_res)
|
||||
|
||||
output = [
|
||||
chunk
|
||||
async for chunk in cast(
|
||||
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
|
||||
)
|
||||
chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("")
|
||||
]
|
||||
assert output == list(llm_res)
|
||||
|
||||
@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
||||
config: RunnableConfig = {"callbacks": [tracer]}
|
||||
|
||||
assert [
|
||||
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
|
||||
_
|
||||
async for _ in RunnableLambda[str, str](lambda _: llm).astream(
|
||||
"", config=config
|
||||
)
|
||||
] == list(llm_res)
|
||||
|
||||
assert len(tracer.runs) == 1
|
||||
@ -5300,7 +5300,7 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
||||
def func(_input: dict, /) -> Runnable:
|
||||
return idchain
|
||||
|
||||
assert await RunnableLambda(func).ainvoke({})
|
||||
assert await RunnableLambda[dict, bool](func).ainvoke({})
|
||||
|
||||
|
||||
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||
|
@ -210,7 +210,7 @@ def test_decorator_with_specified_schema() -> None:
|
||||
return f"{arg1} {arg2} {arg3}"
|
||||
|
||||
assert isinstance(tool_func_v1, BaseTool)
|
||||
assert tool_func_v1.args_schema == _MockSchemaV1
|
||||
assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)
|
||||
|
||||
|
||||
def test_decorated_function_schema_equivalent() -> None:
|
||||
|
@ -15,7 +15,7 @@ from langchain_core.utils.aiter import abatch_iterate
|
||||
],
|
||||
)
|
||||
async def test_abatch_iterate(
|
||||
input_size: int, input_iterable: list[str], expected_output: list[str]
|
||||
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
|
||||
) -> None:
|
||||
"""Test batching function."""
|
||||
|
||||
|
@ -13,7 +13,7 @@ from langchain_core.utils.iter import batch_iterate
|
||||
],
|
||||
)
|
||||
def test_batch_iterate(
|
||||
input_size: int, input_iterable: list[str], expected_output: list[str]
|
||||
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
|
||||
) -> None:
|
||||
"""Test batching function."""
|
||||
assert list(batch_iterate(input_size, input_iterable)) == expected_output
|
||||
|
@ -1,5 +1,7 @@
|
||||
"""Tests for verifying that testing utility code works as expected."""
|
||||
|
||||
import operator
|
||||
from functools import reduce
|
||||
from itertools import cycle
|
||||
from typing import Any, Optional, Union
|
||||
from uuid import UUID
|
||||
@ -107,12 +109,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
||||
),
|
||||
]
|
||||
|
||||
accumulate_chunks = None
|
||||
for chunk in chunks:
|
||||
if accumulate_chunks is None:
|
||||
accumulate_chunks = chunk
|
||||
else:
|
||||
accumulate_chunks += chunk
|
||||
accumulate_chunks = reduce(operator.add, chunks)
|
||||
|
||||
assert accumulate_chunks == AIMessageChunk(
|
||||
id="a1",
|
||||
|
Loading…
Reference in New Issue
Block a user