mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-17 08:29:28 +00:00
core: Add mypy strict-equality rule (#31286)
This commit is contained in:
parent
2c4e0ab3bc
commit
539e5b6936
@ -51,7 +51,6 @@ from langchain_core.messages import (
|
|||||||
AIMessage,
|
AIMessage,
|
||||||
AnyMessage,
|
AnyMessage,
|
||||||
BaseMessage,
|
BaseMessage,
|
||||||
BaseMessageChunk,
|
|
||||||
HumanMessage,
|
HumanMessage,
|
||||||
convert_to_messages,
|
convert_to_messages,
|
||||||
convert_to_openai_image_block,
|
convert_to_openai_image_block,
|
||||||
@ -446,13 +445,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[BaseMessageChunk]:
|
) -> Iterator[BaseMessage]:
|
||||||
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
if not self._should_stream(async_api=False, **{**kwargs, "stream": True}):
|
||||||
# model doesn't implement streaming, so use default implementation
|
# model doesn't implement streaming, so use default implementation
|
||||||
yield cast(
|
yield self.invoke(input, config=config, stop=stop, **kwargs)
|
||||||
"BaseMessageChunk",
|
|
||||||
self.invoke(input, config=config, stop=stop, **kwargs),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
messages = self._convert_input(input).to_messages()
|
messages = self._convert_input(input).to_messages()
|
||||||
@ -537,13 +533,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
*,
|
*,
|
||||||
stop: Optional[list[str]] = None,
|
stop: Optional[list[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[BaseMessageChunk]:
|
) -> AsyncIterator[BaseMessage]:
|
||||||
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
if not self._should_stream(async_api=True, **{**kwargs, "stream": True}):
|
||||||
# No async or sync stream is implemented, so fall back to ainvoke
|
# No async or sync stream is implemented, so fall back to ainvoke
|
||||||
yield cast(
|
yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
|
||||||
"BaseMessageChunk",
|
|
||||||
await self.ainvoke(input, config=config, stop=stop, **kwargs),
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
config = ensure_config(config)
|
config = ensure_config(config)
|
||||||
@ -1454,7 +1447,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
|||||||
PydanticToolsParser,
|
PydanticToolsParser,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.bind_tools is BaseChatModel.bind_tools:
|
if type(self).bind_tools is BaseChatModel.bind_tools:
|
||||||
msg = "with_structured_output is not implemented for this model."
|
msg = "with_structured_output is not implemented for this model."
|
||||||
raise NotImplementedError(msg)
|
raise NotImplementedError(msg)
|
||||||
|
|
||||||
|
@ -4331,8 +4331,9 @@ class RunnableLambda(Runnable[Input, Output]):
|
|||||||
self,
|
self,
|
||||||
func: Union[
|
func: Union[
|
||||||
Union[
|
Union[
|
||||||
Callable[[Input], Output],
|
|
||||||
Callable[[Input], Iterator[Output]],
|
Callable[[Input], Iterator[Output]],
|
||||||
|
Callable[[Input], Runnable[Input, Output]],
|
||||||
|
Callable[[Input], Output],
|
||||||
Callable[[Input, RunnableConfig], Output],
|
Callable[[Input, RunnableConfig], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun], Output],
|
Callable[[Input, CallbackManagerForChainRun], Output],
|
||||||
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output],
|
||||||
|
@ -77,6 +77,7 @@ if IS_PYDANTIC_V1:
|
|||||||
TypeBaseModel = type[BaseModel]
|
TypeBaseModel = type[BaseModel]
|
||||||
elif IS_PYDANTIC_V2:
|
elif IS_PYDANTIC_V2:
|
||||||
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment]
|
||||||
|
from pydantic.v1.fields import ModelField
|
||||||
|
|
||||||
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
|
||||||
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc]
|
||||||
@ -373,20 +374,20 @@ if IS_PYDANTIC_V2:
|
|||||||
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ...
|
def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ...
|
def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ...
|
||||||
|
|
||||||
def get_fields(
|
def get_fields(
|
||||||
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1],
|
||||||
) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]:
|
) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]:
|
||||||
"""Get the field names of a Pydantic model."""
|
"""Get the field names of a Pydantic model."""
|
||||||
if hasattr(model, "model_fields"):
|
if hasattr(model, "model_fields"):
|
||||||
return model.model_fields
|
return model.model_fields
|
||||||
|
|
||||||
if hasattr(model, "__fields__"):
|
if hasattr(model, "__fields__"):
|
||||||
return model.__fields__ # type: ignore[return-value]
|
return model.__fields__
|
||||||
msg = f"Expected a Pydantic model. Got {type(model)}"
|
msg = f"Expected a Pydantic model. Got {type(model)}"
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
|
|
||||||
|
@ -74,7 +74,6 @@ report_deprecated_as_note = "True"
|
|||||||
# TODO: activate for 'strict' checking
|
# TODO: activate for 'strict' checking
|
||||||
disallow_any_generics = "False"
|
disallow_any_generics = "False"
|
||||||
warn_return_any = "False"
|
warn_return_any = "False"
|
||||||
strict_equality = "False"
|
|
||||||
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Tests for verifying that testing utility code works as expected."""
|
"""Tests for verifying that testing utility code works as expected."""
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from functools import reduce
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
]
|
]
|
||||||
assert len({chunk.id for chunk in chunks}) == 1
|
assert len({chunk.id for chunk in chunks}) == 1
|
||||||
|
|
||||||
accumulate_chunks = None
|
accumulate_chunks = reduce(operator.add, chunks)
|
||||||
for chunk in chunks:
|
|
||||||
if accumulate_chunks is None:
|
|
||||||
accumulate_chunks = chunk
|
|
||||||
else:
|
|
||||||
accumulate_chunks += chunk
|
|
||||||
|
|
||||||
assert accumulate_chunks == AIMessageChunk(
|
assert accumulate_chunks == AIMessageChunk(
|
||||||
content="",
|
content="",
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
"""Test PydanticOutputParser."""
|
"""Test PydanticOutputParser."""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
import pydantic
|
import pydantic
|
||||||
import pytest
|
import pytest
|
||||||
@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
@pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1])
|
||||||
def test_pydantic_parser_chaining(
|
def test_pydantic_parser_chaining(
|
||||||
pydantic_object: TBaseModel,
|
pydantic_object: Union[type[ForecastV2], type[ForecastV1]],
|
||||||
) -> None:
|
) -> None:
|
||||||
prompt = PromptTemplate(
|
prompt = PromptTemplate(
|
||||||
template="""{{
|
template="""{{
|
||||||
@ -43,11 +43,11 @@ def test_pydantic_parser_chaining(
|
|||||||
|
|
||||||
model = ParrotFakeChatModel()
|
model = ParrotFakeChatModel()
|
||||||
|
|
||||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
|
||||||
chain = prompt | model | parser
|
chain = prompt | model | parser
|
||||||
|
|
||||||
res = chain.invoke({})
|
res = chain.invoke({})
|
||||||
assert type(res) is pydantic_object
|
assert isinstance(res, pydantic_object)
|
||||||
assert res.f_or_c == "C"
|
assert res.f_or_c == "C"
|
||||||
assert res.temperature == 20
|
assert res.temperature == 20
|
||||||
assert res.forecast == "Sunny"
|
assert res.forecast == "Sunny"
|
||||||
|
@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None:
|
|||||||
template = " {{''.__class__.__bases__[0] }} " # malicious code
|
template = " {{''.__class__.__bases__[0] }} " # malicious code
|
||||||
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
prompt = PromptTemplate.from_template(template, template_format="jinja2")
|
||||||
with pytest.raises(jinja2.exceptions.SecurityError):
|
with pytest.raises(jinja2.exceptions.SecurityError):
|
||||||
assert prompt.format() == []
|
prompt.format()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("jinja2")
|
@pytest.mark.requires("jinja2")
|
||||||
|
@ -51,7 +51,7 @@ def test_structured_prompt_pydantic() -> None:
|
|||||||
|
|
||||||
chain = prompt | model
|
chain = prompt | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)
|
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_prompt_dict() -> None:
|
def test_structured_prompt_dict() -> None:
|
||||||
@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None:
|
|||||||
|
|
||||||
chain = prompt | model
|
chain = prompt | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
||||||
|
|
||||||
chain = loads(dumps(prompt)) | model
|
chain = loads(dumps(prompt)) | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_prompt_kwargs() -> None:
|
def test_structured_prompt_kwargs() -> None:
|
||||||
@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None:
|
|||||||
)
|
)
|
||||||
model = FakeStructuredChatModel(responses=[])
|
model = FakeStructuredChatModel(responses=[])
|
||||||
chain = prompt | model
|
chain = prompt | model
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
|
||||||
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
assert loads(dumps(prompt)).model_dump() == prompt.model_dump()
|
||||||
chain = loads(dumps(prompt)) | model
|
chain = loads(dumps(prompt)) | model
|
||||||
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7}
|
assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
class OutputSchema(BaseModel):
|
class OutputSchema(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@ -116,7 +116,7 @@ def test_structured_prompt_kwargs() -> None:
|
|||||||
|
|
||||||
chain = prompt | model
|
chain = prompt | model
|
||||||
|
|
||||||
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7)
|
assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) # type: ignore[comparison-overlap]
|
||||||
|
|
||||||
|
|
||||||
def test_structured_prompt_template_format() -> None:
|
def test_structured_prompt_template_format() -> None:
|
||||||
|
@ -671,7 +671,7 @@ def test_with_types_with_type_generics() -> None:
|
|||||||
|
|
||||||
def test_schema_with_itemgetter() -> None:
|
def test_schema_with_itemgetter() -> None:
|
||||||
"""Test runnable with itemgetter."""
|
"""Test runnable with itemgetter."""
|
||||||
foo = RunnableLambda(itemgetter("hello"))
|
foo: Runnable = RunnableLambda(itemgetter("hello"))
|
||||||
assert _schema(foo.input_schema) == {
|
assert _schema(foo.input_schema) == {
|
||||||
"properties": {"hello": {"title": "Hello"}},
|
"properties": {"hello": {"title": "Hello"}},
|
||||||
"required": ["hello"],
|
"required": ["hello"],
|
||||||
@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None:
|
|||||||
# sleep to better simulate a real stream
|
# sleep to better simulate a real stream
|
||||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
|
|
||||||
output = list(RunnableLambda(lambda _: llm).stream(""))
|
output = list(RunnableLambda[str, str](lambda _: llm).stream(""))
|
||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
|
|
||||||
|
|
||||||
@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None:
|
|||||||
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01)
|
||||||
config: RunnableConfig = {"callbacks": [tracer]}
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list(
|
assert list(
|
||||||
llm_res
|
RunnableLambda[str, str](lambda _: llm).stream("", config=config)
|
||||||
)
|
) == list(llm_res)
|
||||||
|
|
||||||
assert len(tracer.runs) == 1
|
assert len(tracer.runs) == 1
|
||||||
assert tracer.runs[0].error is None
|
assert tracer.runs[0].error is None
|
||||||
@ -4075,10 +4075,7 @@ async def test_runnable_lambda_astream() -> None:
|
|||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
|
|
||||||
output = [
|
output = [
|
||||||
chunk
|
chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("")
|
||||||
async for chunk in cast(
|
|
||||||
"AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("")
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
assert output == list(llm_res)
|
assert output == list(llm_res)
|
||||||
|
|
||||||
@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None:
|
|||||||
config: RunnableConfig = {"callbacks": [tracer]}
|
config: RunnableConfig = {"callbacks": [tracer]}
|
||||||
|
|
||||||
assert [
|
assert [
|
||||||
_ async for _ in RunnableLambda(lambda _: llm).astream("", config=config)
|
_
|
||||||
|
async for _ in RunnableLambda[str, str](lambda _: llm).astream(
|
||||||
|
"", config=config
|
||||||
|
)
|
||||||
] == list(llm_res)
|
] == list(llm_res)
|
||||||
|
|
||||||
assert len(tracer.runs) == 1
|
assert len(tracer.runs) == 1
|
||||||
@ -5300,7 +5300,7 @@ async def test_ainvoke_on_returned_runnable() -> None:
|
|||||||
def func(_input: dict, /) -> Runnable:
|
def func(_input: dict, /) -> Runnable:
|
||||||
return idchain
|
return idchain
|
||||||
|
|
||||||
assert await RunnableLambda(func).ainvoke({})
|
assert await RunnableLambda[dict, bool](func).ainvoke({})
|
||||||
|
|
||||||
|
|
||||||
def test_invoke_stream_passthrough_assign_trace() -> None:
|
def test_invoke_stream_passthrough_assign_trace() -> None:
|
||||||
|
@ -210,7 +210,7 @@ def test_decorator_with_specified_schema() -> None:
|
|||||||
return f"{arg1} {arg2} {arg3}"
|
return f"{arg1} {arg2} {arg3}"
|
||||||
|
|
||||||
assert isinstance(tool_func_v1, BaseTool)
|
assert isinstance(tool_func_v1, BaseTool)
|
||||||
assert tool_func_v1.args_schema == _MockSchemaV1
|
assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1)
|
||||||
|
|
||||||
|
|
||||||
def test_decorated_function_schema_equivalent() -> None:
|
def test_decorated_function_schema_equivalent() -> None:
|
||||||
|
@ -15,7 +15,7 @@ from langchain_core.utils.aiter import abatch_iterate
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_abatch_iterate(
|
async def test_abatch_iterate(
|
||||||
input_size: int, input_iterable: list[str], expected_output: list[str]
|
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test batching function."""
|
"""Test batching function."""
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ from langchain_core.utils.iter import batch_iterate
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_batch_iterate(
|
def test_batch_iterate(
|
||||||
input_size: int, input_iterable: list[str], expected_output: list[str]
|
input_size: int, input_iterable: list[str], expected_output: list[list[str]]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test batching function."""
|
"""Test batching function."""
|
||||||
assert list(batch_iterate(input_size, input_iterable)) == expected_output
|
assert list(batch_iterate(input_size, input_iterable)) == expected_output
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
"""Tests for verifying that testing utility code works as expected."""
|
"""Tests for verifying that testing utility code works as expected."""
|
||||||
|
|
||||||
|
import operator
|
||||||
|
from functools import reduce
|
||||||
from itertools import cycle
|
from itertools import cycle
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@ -107,12 +109,7 @@ async def test_generic_fake_chat_model_stream() -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
accumulate_chunks = None
|
accumulate_chunks = reduce(operator.add, chunks)
|
||||||
for chunk in chunks:
|
|
||||||
if accumulate_chunks is None:
|
|
||||||
accumulate_chunks = chunk
|
|
||||||
else:
|
|
||||||
accumulate_chunks += chunk
|
|
||||||
|
|
||||||
assert accumulate_chunks == AIMessageChunk(
|
assert accumulate_chunks == AIMessageChunk(
|
||||||
id="a1",
|
id="a1",
|
||||||
|
Loading…
Reference in New Issue
Block a user