From 02d6b9106baa9f23f195abfd5d6d948d85eee29b Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 19 Aug 2025 15:39:53 +0200 Subject: [PATCH] chore(core): add mypy pydantic plugin (#32604) This helps to remove a bunch of mypy false positives. --- .../langchain_core/beta/runnables/context.py | 2 +- libs/core/langchain_core/documents/base.py | 2 +- .../language_models/chat_models.py | 18 +++++++++--------- libs/core/langchain_core/prompts/chat.py | 4 +--- libs/core/langchain_core/runnables/base.py | 10 +++++----- libs/core/langchain_core/runnables/branch.py | 2 +- libs/core/langchain_core/runnables/history.py | 2 +- .../langchain_core/runnables/passthrough.py | 6 +++--- libs/core/langchain_core/runnables/retry.py | 2 +- libs/core/langchain_core/runnables/router.py | 2 +- libs/core/langchain_core/structured_query.py | 10 +++------- libs/core/langchain_core/tools/structured.py | 2 +- libs/core/pyproject.toml | 1 + .../output_parsers/test_openai_tools.py | 2 +- .../output_parsers/test_pydantic_parser.py | 14 ++++++++------ .../tests/unit_tests/prompts/test_prompt.py | 2 +- .../unit_tests/runnables/test_configurable.py | 10 +++++----- libs/core/tests/unit_tests/utils/test_utils.py | 2 +- 18 files changed, 45 insertions(+), 48 deletions(-) diff --git a/libs/core/langchain_core/beta/runnables/context.py b/libs/core/langchain_core/beta/runnables/context.py index 9901913ab31..993b4665b2b 100644 --- a/libs/core/langchain_core/beta/runnables/context.py +++ b/libs/core/langchain_core/beta/runnables/context.py @@ -253,7 +253,7 @@ class ContextSet(RunnableSerializable): """ if key is not None: kwargs[key] = value - super().__init__( # type: ignore[call-arg] + super().__init__( keys={ k: _coerce_set_value(v) if v is not None else None for k, v in kwargs.items() diff --git a/libs/core/langchain_core/documents/base.py b/libs/core/langchain_core/documents/base.py index b22ee910bc0..601e831a959 100644 --- a/libs/core/langchain_core/documents/base.py +++ b/libs/core/langchain_core/documents/base.py @@ -277,7 +277,7 @@ class Document(BaseMedia): """Pass page_content in as positional or named arg.""" # my-py is complaining that page_content is not defined on the base class. # Here, we're relying on pydantic base class to handle the validation. - super().__init__(page_content=page_content, **kwargs) # type: ignore[call-arg] + super().__init__(page_content=page_content, **kwargs) @classmethod def is_lc_serializable(cls) -> bool: diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index 310f392fd25..5fcc760d8ab 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -533,7 +533,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): generations = [generations_with_error_metadata] run_manager.on_llm_error( e, - response=LLMResult(generations=generations), # type: ignore[arg-type] + response=LLMResult(generations=generations), ) raise @@ -627,7 +627,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): generations = [generations_with_error_metadata] await run_manager.on_llm_error( e, - response=LLMResult(generations=generations), # type: ignore[arg-type] + response=LLMResult(generations=generations), ) raise @@ -842,17 +842,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): run_managers[i].on_llm_error( e, response=LLMResult( - generations=[generations_with_error_metadata] # type: ignore[list-item] + generations=[generations_with_error_metadata] ), ) raise flattened_outputs = [ - LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item] + LLMResult(generations=[res.generations], llm_output=res.llm_output) for res in results ] llm_output = self._combine_llm_outputs([res.llm_output for res in results]) generations = [res.generations for res in results] - output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type] + output = LLMResult(generations=generations, llm_output=llm_output) if run_managers: run_infos = [] for manager, flattened_output in zip(run_managers, flattened_outputs): @@ -962,7 +962,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): await run_managers[i].on_llm_error( res, response=LLMResult( - generations=[generations_with_error_metadata] # type: ignore[list-item] + generations=[generations_with_error_metadata] ), ) exceptions.append(res) @@ -972,7 +972,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): *[ run_manager.on_llm_end( LLMResult( - generations=[res.generations], # type: ignore[list-item, union-attr] + generations=[res.generations], # type: ignore[union-attr] llm_output=res.llm_output, # type: ignore[union-attr] ) ) @@ -982,12 +982,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC): ) raise exceptions[0] flattened_outputs = [ - LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item, union-attr] + LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[union-attr] for res in results ] llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr] generations = [res.generations for res in results] # type: ignore[union-attr] - output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type] + output = LLMResult(generations=generations, llm_output=llm_output) await asyncio.gather( *[ run_manager.on_llm_end(flattened_output) diff --git a/libs/core/langchain_core/prompts/chat.py b/libs/core/langchain_core/prompts/chat.py index cfc0a0b38df..489cbe7a703 100644 --- a/libs/core/langchain_core/prompts/chat.py +++ b/libs/core/langchain_core/prompts/chat.py @@ -155,9 +155,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate): """ # mypy can't detect the init which is defined in the parent class # b/c these are BaseModel classes. - super().__init__( # type: ignore[call-arg] - variable_name=variable_name, optional=optional, **kwargs - ) + super().__init__(variable_name=variable_name, optional=optional, **kwargs) def format_messages(self, **kwargs: Any) -> list[BaseMessage]: """Format messages from kwargs. diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 38da82fcda3..7b872783909 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -2819,7 +2819,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): if len(steps_flat) < 2: msg = f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}" raise ValueError(msg) - super().__init__( # type: ignore[call-arg] + super().__init__( first=steps_flat[0], middle=list(steps_flat[1:-1]), last=steps_flat[-1], @@ -3612,7 +3612,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]): """ merged = {**steps__} if steps__ is not None else {} merged.update(kwargs) - super().__init__( # type: ignore[call-arg] + super().__init__( steps__={key: coerce_to_runnable(r) for key, r in merged.items()} ) @@ -5325,7 +5325,7 @@ class RunnableEach(RunnableEachBase[Input, Output]): ) -class RunnableBindingBase(RunnableSerializable[Input, Output]): +class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[no-redef] """``Runnable`` that delegates calls to another ``Runnable`` with a set of kwargs. Use only if creating a new ``RunnableBinding`` subclass with different ``__init__`` @@ -5404,7 +5404,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): ``Runnable`` with a custom type. Defaults to None. **other_kwargs: Unpacked into the base class. """ # noqa: E501 - super().__init__( # type: ignore[call-arg] + super().__init__( bound=bound, kwargs=kwargs or {}, config=config or {}, @@ -5729,7 +5729,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]): yield item -class RunnableBinding(RunnableBindingBase[Input, Output]): +class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-redef] """Wrap a ``Runnable`` with additional functionality. A ``RunnableBinding`` can be thought of as a "runnable decorator" that diff --git a/libs/core/langchain_core/runnables/branch.py b/libs/core/langchain_core/runnables/branch.py index dffbc79310a..f0975523e15 100644 --- a/libs/core/langchain_core/runnables/branch.py +++ b/libs/core/langchain_core/runnables/branch.py @@ -136,7 +136,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): super().__init__( branches=branches_, default=default_, - ) # type: ignore[call-arg] + ) model_config = ConfigDict( arbitrary_types_allowed=True, diff --git a/libs/core/langchain_core/runnables/history.py b/libs/core/langchain_core/runnables/history.py index 7ae97bcd711..42a206b9ae7 100644 --- a/libs/core/langchain_core/runnables/history.py +++ b/libs/core/langchain_core/runnables/history.py @@ -38,7 +38,7 @@ MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] -class RunnableWithMessageHistory(RunnableBindingBase): +class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef] """Runnable that manages chat message history for another Runnable. A chat message history is a sequence of messages that represent a conversation. diff --git a/libs/core/langchain_core/runnables/passthrough.py b/libs/core/langchain_core/runnables/passthrough.py index bed4ff8a820..7ca0634dfe9 100644 --- a/libs/core/langchain_core/runnables/passthrough.py +++ b/libs/core/langchain_core/runnables/passthrough.py @@ -186,7 +186,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]): afunc = func func = None - super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg] + super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) @classmethod @override @@ -406,7 +406,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]): mapper: A ``RunnableParallel`` instance that will be used to transform the input dictionary. """ - super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] + super().__init__(mapper=mapper, **kwargs) @classmethod @override @@ -710,7 +710,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]): Args: keys: A single key or a list of keys to pick from the input dictionary. """ - super().__init__(keys=keys, **kwargs) # type: ignore[call-arg] + super().__init__(keys=keys, **kwargs) @classmethod @override diff --git a/libs/core/langchain_core/runnables/retry.py b/libs/core/langchain_core/runnables/retry.py index d495c59e6e9..e909b9b175b 100644 --- a/libs/core/langchain_core/runnables/retry.py +++ b/libs/core/langchain_core/runnables/retry.py @@ -47,7 +47,7 @@ class ExponentialJitterParams(TypedDict, total=False): """Random additional wait sampled from random.uniform(0, jitter).""" -class RunnableRetry(RunnableBindingBase[Input, Output]): +class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-redef] """Retry a Runnable if it fails. RunnableRetry can be used to add retry logic to any object diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index c6af3168731..7cc947fe9a8 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -87,7 +87,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): Args: runnables: A mapping of keys to Runnables. """ - super().__init__( # type: ignore[call-arg] + super().__init__( runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} ) diff --git a/libs/core/langchain_core/structured_query.py b/libs/core/langchain_core/structured_query.py index 5a1a1eb9667..61fb8f0e53a 100644 --- a/libs/core/langchain_core/structured_query.py +++ b/libs/core/langchain_core/structured_query.py @@ -143,7 +143,7 @@ class Comparison(FilterDirective): value: The value to compare to. """ # super exists from BaseModel - super().__init__( # type: ignore[call-arg] + super().__init__( comparator=comparator, attribute=attribute, value=value, **kwargs ) @@ -166,9 +166,7 @@ class Operation(FilterDirective): arguments: The arguments to the operator. """ # super exists from BaseModel - super().__init__( # type: ignore[call-arg] - operator=operator, arguments=arguments, **kwargs - ) + super().__init__(operator=operator, arguments=arguments, **kwargs) class StructuredQuery(Expr): @@ -196,6 +194,4 @@ class StructuredQuery(Expr): limit: The limit on the number of results. """ # super exists from BaseModel - super().__init__( # type: ignore[call-arg] - query=query, filter=filter, limit=limit, **kwargs - ) + super().__init__(query=query, filter=filter, limit=limit, **kwargs) diff --git a/libs/core/langchain_core/tools/structured.py b/libs/core/langchain_core/tools/structured.py index a419a1ede62..7fcc3c26e5a 100644 --- a/libs/core/langchain_core/tools/structured.py +++ b/libs/core/langchain_core/tools/structured.py @@ -228,7 +228,7 @@ class StructuredTool(BaseTool): name=name, func=func, coroutine=coroutine, - args_schema=args_schema, # type: ignore[arg-type] + args_schema=args_schema, description=description_, return_direct=return_direct, response_format=response_format, diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index 4917b2a52a9..d643514e77c 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -64,6 +64,7 @@ langchain-text-splitters = { path = "../text-splitters" } [tool.mypy] +plugins = ["pydantic.mypy"] strict = "True" strict_bytes = "True" enable_error_code = "deprecated" diff --git a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py index 74862a8386a..7258f5bd20f 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py +++ b/libs/core/tests/unit_tests/output_parsers/test_openai_tools.py @@ -810,7 +810,7 @@ def test_parse_with_different_pydantic_2_v1() -> None: # Can't get pydantic to work here due to the odd typing of tryig to support # both v1 and v2 in the same codebase. - parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item] + parser = PydanticToolsParser(tools=[Forecast]) message = AIMessage( content="", tool_calls=[ diff --git a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py index 0486878749a..a72aa926780 100644 --- a/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py +++ b/libs/core/tests/unit_tests/output_parsers/test_pydantic_parser.py @@ -13,7 +13,7 @@ from langchain_core.language_models import ParrotFakeChatModel from langchain_core.output_parsers import PydanticOutputParser from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.prompts.prompt import PromptTemplate -from langchain_core.utils.pydantic import TBaseModel +from langchain_core.utils.pydantic import PydanticBaseModel, TBaseModel class ForecastV2(pydantic.BaseModel): @@ -43,7 +43,7 @@ def test_pydantic_parser_chaining( model = ParrotFakeChatModel() - parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var] + parser = PydanticOutputParser[PydanticBaseModel](pydantic_object=pydantic_object) chain = prompt | model | parser res = chain.invoke({}) @@ -66,7 +66,9 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None: model = ParrotFakeChatModel() - parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated] + parser: PydanticOutputParser[PydanticBaseModel] = PydanticOutputParser( + pydantic_object=pydantic_object + ) chain = bad_prompt | model | parser with pytest.raises(OutputParserException): chain.invoke({}) @@ -88,7 +90,7 @@ def test_json_parser_chaining( model = ParrotFakeChatModel() - parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type] + parser = JsonOutputParser(pydantic_object=pydantic_object) chain = prompt | model | parser res = chain.invoke({}) @@ -171,7 +173,7 @@ def test_pydantic_output_parser_type_inference() -> None: # Ignoring mypy error that appears in python 3.8, but not 3.11. # This seems to be functionally correct, so we'll ignore the error. - pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) + pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel) schema = pydantic_parser.get_output_schema().model_json_schema() assert schema == { @@ -202,5 +204,5 @@ def test_format_instructions_preserves_language() -> None: ) ) - parser = PydanticOutputParser(pydantic_object=Foo) + parser = PydanticOutputParser[Foo](pydantic_object=Foo) assert description in parser.get_format_instructions() diff --git a/libs/core/tests/unit_tests/prompts/test_prompt.py b/libs/core/tests/unit_tests/prompts/test_prompt.py index e092eb66581..09d26438cd4 100644 --- a/libs/core/tests/unit_tests/prompts/test_prompt.py +++ b/libs/core/tests/unit_tests/prompts/test_prompt.py @@ -348,7 +348,7 @@ def test_prompt_invalid_template_format() -> None: PromptTemplate( input_variables=input_variables, template=template, - template_format="bar", # type: ignore[arg-type] + template_format="bar", ) diff --git a/libs/core/tests/unit_tests/runnables/test_configurable.py b/libs/core/tests/unit_tests/runnables/test_configurable.py index be4bf4759e5..86bf14adf16 100644 --- a/libs/core/tests/unit_tests/runnables/test_configurable.py +++ b/libs/core/tests/unit_tests/runnables/test_configurable.py @@ -73,7 +73,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]): def test_doubly_set_configurable() -> None: """Test that setting a configurable field with a default value works.""" - runnable = MyRunnable(my_property="a") # type: ignore[call-arg] + runnable = MyRunnable(my_property="a") configurable_runnable = runnable.configurable_fields( my_property=ConfigurableField( id="my_property", @@ -86,7 +86,7 @@ def test_doubly_set_configurable() -> None: def test_alias_set_configurable() -> None: - runnable = MyRunnable(my_property="a") # type: ignore[call-arg] + runnable = MyRunnable(my_property="a") configurable_runnable = runnable.configurable_fields( my_property=ConfigurableField( id="my_property_alias", @@ -104,7 +104,7 @@ def test_alias_set_configurable() -> None: def test_field_alias_set_configurable() -> None: - runnable = MyRunnable(my_property_alias="a") + runnable = MyRunnable(my_property_alias="a") # type: ignore[call-arg] configurable_runnable = runnable.configurable_fields( my_property=ConfigurableField( id="my_property", @@ -122,7 +122,7 @@ def test_field_alias_set_configurable() -> None: def test_config_passthrough() -> None: - runnable = MyRunnable(my_property="a") # type: ignore[call-arg] + runnable = MyRunnable(my_property="a") configurable_runnable = runnable.configurable_fields( my_property=ConfigurableField( id="my_property", @@ -158,7 +158,7 @@ def test_config_passthrough() -> None: def test_config_passthrough_nested() -> None: - runnable = MyRunnable(my_property="a") # type: ignore[call-arg] + runnable = MyRunnable(my_property="a") configurable_runnable = runnable.configurable_fields( my_property=ConfigurableField( id="my_property", diff --git a/libs/core/tests/unit_tests/utils/test_utils.py b/libs/core/tests/unit_tests/utils/test_utils.py index 048addf0e28..c38b951da65 100644 --- a/libs/core/tests/unit_tests/utils/test_utils.py +++ b/libs/core/tests/unit_tests/utils/test_utils.py @@ -347,7 +347,7 @@ def test_using_secret_from_env_as_default_factory( secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY")) # Pass the secret as a parameter - foo = Foo(secret="super_secret") # type: ignore[arg-type] + foo = Foo(secret="super_secret") assert foo.secret.get_secret_value() == "super_secret" # Set the environment variable