diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 1070a8877e4..5b7165e83a0 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -51,7 +51,6 @@ from langchain_core.messages import ( AIMessage, AnyMessage, BaseMessage, - BaseMessageChunk, HumanMessage, convert_to_messages, convert_to_openai_image_block, @@ -446,13 +445,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): *, stop: Optional[list[str]] = None, **kwargs: Any, - ) -> Iterator[BaseMessageChunk]: + ) -> Iterator[BaseMessage]: if not self._should_stream(async_api=False, **{**kwargs, "stream": True}): # model doesn't implement streaming, so use default implementation - yield cast( - "BaseMessageChunk", - self.invoke(input, config=config, stop=stop, **kwargs), - ) + yield self.invoke(input, config=config, stop=stop, **kwargs) else: config = ensure_config(config) messages = self._convert_input(input).to_messages() @@ -537,13 +533,10 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): *, stop: Optional[list[str]] = None, **kwargs: Any, - ) -> AsyncIterator[BaseMessageChunk]: + ) -> AsyncIterator[BaseMessage]: if not self._should_stream(async_api=True, **{**kwargs, "stream": True}): # No async or sync stream is implemented, so fall back to ainvoke - yield cast( - "BaseMessageChunk", - await self.ainvoke(input, config=config, stop=stop, **kwargs), - ) + yield await self.ainvoke(input, config=config, stop=stop, **kwargs) return config = ensure_config(config) @@ -1454,7 +1447,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): PydanticToolsParser, ) - if self.bind_tools is BaseChatModel.bind_tools: + if type(self).bind_tools is BaseChatModel.bind_tools: msg = "with_structured_output is not implemented for this model." raise NotImplementedError(msg) diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index b408ef8f116..2a044c52a13 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -4331,8 +4331,9 @@ class RunnableLambda(Runnable[Input, Output]): self, func: Union[ Union[ - Callable[[Input], Output], Callable[[Input], Iterator[Output]], + Callable[[Input], Runnable[Input, Output]], + Callable[[Input], Output], Callable[[Input, RunnableConfig], Output], Callable[[Input, CallbackManagerForChainRun], Output], Callable[[Input, CallbackManagerForChainRun, RunnableConfig], Output], diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index ea987741ce3..cab3d4da6c4 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -77,6 +77,7 @@ if IS_PYDANTIC_V1: TypeBaseModel = type[BaseModel] elif IS_PYDANTIC_V2: from pydantic.v1.fields import FieldInfo as FieldInfoV1 # type: ignore[assignment] + from pydantic.v1.fields import ModelField # Union type needs to be last assignment to PydanticBaseModel to make mypy happy. PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore[assignment,misc] @@ -373,20 +374,20 @@ if IS_PYDANTIC_V2: def get_fields(model: BaseModelV2) -> dict[str, FieldInfoV2]: ... @overload - def get_fields(model: type[BaseModelV1]) -> dict[str, FieldInfoV1]: ... + def get_fields(model: type[BaseModelV1]) -> dict[str, ModelField]: ... @overload - def get_fields(model: BaseModelV1) -> dict[str, FieldInfoV1]: ... + def get_fields(model: BaseModelV1) -> dict[str, ModelField]: ... def get_fields( model: Union[type[Union[BaseModelV2, BaseModelV1]], BaseModelV2, BaseModelV1], - ) -> Union[dict[str, FieldInfoV2], dict[str, FieldInfoV1]]: + ) -> Union[dict[str, FieldInfoV2], dict[str, ModelField]]: """Get the field names of a Pydantic model.""" if hasattr(model, "model_fields"): return model.model_fields if hasattr(model, "__fields__"): - return model.__fields__ # type: ignore[return-value] + return model.__fields__ msg = f"Expected a Pydantic model. Got {type(model)}" raise TypeError(msg) diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 9637202b901..b31ee92cd83 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -74,7 +74,6 @@ report_deprecated_as_note = "True" # TODO: activate for 'strict' checking disallow_any_generics = "False" warn_return_any = "False" -strict_equality = "False" [tool.ruff] diff --git a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py index 7500e1640ac..c39440f5e32 100644 --- a/libs/core/tests/unit_tests/fake/test_fake_chat_model.py +++ b/libs/core/tests/unit_tests/fake/test_fake_chat_model.py @@ -1,5 +1,7 @@ """Tests for verifying that testing utility code works as expected.""" +import operator +from functools import reduce from itertools import cycle from typing import Any, Optional, Union from uuid import UUID @@ -115,12 +117,7 @@ async def test_generic_fake_chat_model_stream() -> None: ] assert len({chunk.id for chunk in chunks}) == 1 - accumulate_chunks = None - for chunk in chunks: - if accumulate_chunks is None: - accumulate_chunks = chunk - else: - accumulate_chunks += chunk + accumulate_chunks = reduce(operator.add, chunks) assert accumulate_chunks == AIMessageChunk( content="", diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 07d192c1763..0486878749a 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -1,7 +1,7 @@ """Test PydanticOutputParser.""" from enum import Enum -from typing import Literal, Optional +from typing import Literal, Optional, Union import pydantic import pytest @@ -30,7 +30,7 @@ class ForecastV1(V1BaseModel): @pytest.mark.parametrize("pydantic_object", [ForecastV2, ForecastV1]) def test_pydantic_parser_chaining( - pydantic_object: TBaseModel, + pydantic_object: Union[type[ForecastV2], type[ForecastV1]], ) -> None: prompt = PromptTemplate( template="""{{ @@ -43,11 +43,11 @@ def test_pydantic_parser_chaining( model = ParrotFakeChatModel() - parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated] + parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var] chain = prompt | model | parser res = chain.invoke({}) - assert type(res) is pydantic_object + assert isinstance(res, pydantic_object) assert res.f_or_c == "C" assert res.temperature == 20 assert res.forecast == "Sunny" diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index 08a7b81e871..5d1ddd3293c 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -441,7 +441,7 @@ def test_basic_sandboxing_with_jinja2() -> None: template = " {{''.__class__.__bases__[0] }} " # malicious code prompt = PromptTemplate.from_template(template, template_format="jinja2") with pytest.raises(jinja2.exceptions.SecurityError): - assert prompt.format() == [] + prompt.format() @pytest.mark.requires("jinja2") diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 8f04f2029ba..0b74b37cc6f 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -51,7 +51,7 @@ def test_structured_prompt_pydantic() -> None: chain = prompt | model - assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) + assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42) # type: ignore[comparison-overlap] def test_structured_prompt_dict() -> None: @@ -73,13 +73,13 @@ def test_structured_prompt_dict() -> None: chain = prompt | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap] assert loads(dumps(prompt)).model_dump() == prompt.model_dump() chain = loads(dumps(prompt)) | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 42} # type: ignore[comparison-overlap] def test_structured_prompt_kwargs() -> None: @@ -99,10 +99,10 @@ def test_structured_prompt_kwargs() -> None: ) model = FakeStructuredChatModel(responses=[]) chain = prompt | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap] assert loads(dumps(prompt)).model_dump() == prompt.model_dump() chain = loads(dumps(prompt)) | model - assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} + assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 7} # type: ignore[comparison-overlap] class OutputSchema(BaseModel): name: str @@ -116,7 +116,7 @@ def test_structured_prompt_kwargs() -> None: chain = prompt | model - assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) + assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=7) # type: ignore[comparison-overlap] def test_structured_prompt_template_format() -> None: diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index a476975ac7c..db7e898dc99 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -671,7 +671,7 @@ def test_with_types_with_type_generics() -> None: def test_schema_with_itemgetter() -> None: """Test runnable with itemgetter.""" - foo = RunnableLambda(itemgetter("hello")) + foo: Runnable = RunnableLambda(itemgetter("hello")) assert _schema(foo.input_schema) == { "properties": {"hello": {"title": "Hello"}}, "required": ["hello"], @@ -4001,7 +4001,7 @@ def test_runnable_lambda_stream() -> None: # sleep to better simulate a real stream llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) - output = list(RunnableLambda(lambda _: llm).stream("")) + output = list(RunnableLambda[str, str](lambda _: llm).stream("")) assert output == list(llm_res) @@ -4014,9 +4014,9 @@ def test_runnable_lambda_stream_with_callbacks() -> None: llm = FakeStreamingListLLM(responses=[llm_res], sleep=0.01) config: RunnableConfig = {"callbacks": [tracer]} - assert list(RunnableLambda(lambda _: llm).stream("", config=config)) == list( - llm_res - ) + assert list( + RunnableLambda[str, str](lambda _: llm).stream("", config=config) + ) == list(llm_res) assert len(tracer.runs) == 1 assert tracer.runs[0].error is None @@ -4075,10 +4075,7 @@ async def test_runnable_lambda_astream() -> None: assert output == list(llm_res) output = [ - chunk - async for chunk in cast( - "AsyncIterator[str]", RunnableLambda(lambda _: llm).astream("") - ) + chunk async for chunk in RunnableLambda[str, str](lambda _: llm).astream("") ] assert output == list(llm_res) @@ -4093,7 +4090,10 @@ async def test_runnable_lambda_astream_with_callbacks() -> None: config: RunnableConfig = {"callbacks": [tracer]} assert [ - _ async for _ in RunnableLambda(lambda _: llm).astream("", config=config) + _ + async for _ in RunnableLambda[str, str](lambda _: llm).astream( + "", config=config + ) ] == list(llm_res) assert len(tracer.runs) == 1 @@ -5300,7 +5300,7 @@ async def test_ainvoke_on_returned_runnable() -> None: def func(_input: dict, /) -> Runnable: return idchain - assert await RunnableLambda(func).ainvoke({}) + assert await RunnableLambda[dict, bool](func).ainvoke({}) def test_invoke_stream_passthrough_assign_trace() -> None: diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index ac18073dde5..4819634c9e5 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -210,7 +210,7 @@ def test_decorator_with_specified_schema() -> None: return f"{arg1} {arg2} {arg3}" assert isinstance(tool_func_v1, BaseTool) - assert tool_func_v1.args_schema == _MockSchemaV1 + assert tool_func_v1.args_schema == cast("ArgsSchema", _MockSchemaV1) def test_decorated_function_schema_equivalent() -> None: diff --git a/libs/core/tests/unit_tests/utils/test_aiter.py b/libs/core/tests/unit_tests/utils/test_aiter.py index 30c4c6ea59a..078fbdf8ac3 100644 --- a/libs/core/tests/unit_tests/utils/test_aiter.py +++ b/libs/core/tests/unit_tests/utils/test_aiter.py @@ -15,7 +15,7 @@ from langchain_core.utils.aiter import abatch_iterate ], ) async def test_abatch_iterate( - input_size: int, input_iterable: list[str], expected_output: list[str] + input_size: int, input_iterable: list[str], expected_output: list[list[str]] ) -> None: """Test batching function.""" diff --git a/libs/core/tests/unit_tests/utils/test_iter.py b/libs/core/tests/unit_tests/utils/test_iter.py index 0cb3fc66cc5..84e7882d48e 100644 --- a/libs/core/tests/unit_tests/utils/test_iter.py +++ b/libs/core/tests/unit_tests/utils/test_iter.py @@ -13,7 +13,7 @@ from langchain_core.utils.iter import batch_iterate ], ) def test_batch_iterate( - input_size: int, input_iterable: list[str], expected_output: list[str] + input_size: int, input_iterable: list[str], expected_output: list[list[str]] ) -> None: """Test batching function.""" assert list(batch_iterate(input_size, input_iterable)) == expected_output diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py index a403e3d027f..ed8d47a71f2 100644 --- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py +++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py @@ -1,5 +1,7 @@ """Tests for verifying that testing utility code works as expected.""" +import operator +from functools import reduce from itertools import cycle from typing import Any, Optional, Union from uuid import UUID @@ -107,12 +109,7 @@ async def test_generic_fake_chat_model_stream() -> None: ), ] - accumulate_chunks = None - for chunk in chunks: - if accumulate_chunks is None: - accumulate_chunks = chunk - else: - accumulate_chunks += chunk + accumulate_chunks = reduce(operator.add, chunks) assert accumulate_chunks == AIMessageChunk( id="a1",