chore(core): add mypy pydantic plugin (#32604)

This helps to remove a bunch of mypy false positives.
This commit is contained in:
Christophe Bornet 2025-08-19 15:39:53 +02:00 committed by GitHub
parent b470c79f1d
commit 02d6b9106b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 45 additions and 48 deletions

View File

@ -253,7 +253,7 @@ class ContextSet(RunnableSerializable):
""" """
if key is not None: if key is not None:
kwargs[key] = value kwargs[key] = value
super().__init__( # type: ignore[call-arg] super().__init__(
keys={ keys={
k: _coerce_set_value(v) if v is not None else None k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items() for k, v in kwargs.items()

View File

@ -277,7 +277,7 @@ class Document(BaseMedia):
"""Pass page_content in as positional or named arg.""" """Pass page_content in as positional or named arg."""
# my-py is complaining that page_content is not defined on the base class. # my-py is complaining that page_content is not defined on the base class.
# Here, we're relying on pydantic base class to handle the validation. # Here, we're relying on pydantic base class to handle the validation.
super().__init__(page_content=page_content, **kwargs) # type: ignore[call-arg] super().__init__(page_content=page_content, **kwargs)
@classmethod @classmethod
def is_lc_serializable(cls) -> bool: def is_lc_serializable(cls) -> bool:

View File

@ -533,7 +533,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generations = [generations_with_error_metadata] generations = [generations_with_error_metadata]
run_manager.on_llm_error( run_manager.on_llm_error(
e, e,
response=LLMResult(generations=generations), # type: ignore[arg-type] response=LLMResult(generations=generations),
) )
raise raise
@ -627,7 +627,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generations = [generations_with_error_metadata] generations = [generations_with_error_metadata]
await run_manager.on_llm_error( await run_manager.on_llm_error(
e, e,
response=LLMResult(generations=generations), # type: ignore[arg-type] response=LLMResult(generations=generations),
) )
raise raise
@ -842,17 +842,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_managers[i].on_llm_error( run_managers[i].on_llm_error(
e, e,
response=LLMResult( response=LLMResult(
generations=[generations_with_error_metadata] # type: ignore[list-item] generations=[generations_with_error_metadata]
), ),
) )
raise raise
flattened_outputs = [ flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item] LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results for res in results
] ]
llm_output = self._combine_llm_outputs([res.llm_output for res in results]) llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results] generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type] output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers: if run_managers:
run_infos = [] run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs): for manager, flattened_output in zip(run_managers, flattened_outputs):
@ -962,7 +962,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
await run_managers[i].on_llm_error( await run_managers[i].on_llm_error(
res, res,
response=LLMResult( response=LLMResult(
generations=[generations_with_error_metadata] # type: ignore[list-item] generations=[generations_with_error_metadata]
), ),
) )
exceptions.append(res) exceptions.append(res)
@ -972,7 +972,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
*[ *[
run_manager.on_llm_end( run_manager.on_llm_end(
LLMResult( LLMResult(
generations=[res.generations], # type: ignore[list-item, union-attr] generations=[res.generations], # type: ignore[union-attr]
llm_output=res.llm_output, # type: ignore[union-attr] llm_output=res.llm_output, # type: ignore[union-attr]
) )
) )
@ -982,12 +982,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
) )
raise exceptions[0] raise exceptions[0]
flattened_outputs = [ flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item, union-attr] LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[union-attr]
for res in results for res in results
] ]
llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr] llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr]
generations = [res.generations for res in results] # type: ignore[union-attr] generations = [res.generations for res in results] # type: ignore[union-attr]
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type] output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather( await asyncio.gather(
*[ *[
run_manager.on_llm_end(flattened_output) run_manager.on_llm_end(flattened_output)

View File

@ -155,9 +155,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
""" """
# mypy can't detect the init which is defined in the parent class # mypy can't detect the init which is defined in the parent class
# b/c these are BaseModel classes. # b/c these are BaseModel classes.
super().__init__( # type: ignore[call-arg] super().__init__(variable_name=variable_name, optional=optional, **kwargs)
variable_name=variable_name, optional=optional, **kwargs
)
def format_messages(self, **kwargs: Any) -> list[BaseMessage]: def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs. """Format messages from kwargs.

View File

@ -2819,7 +2819,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
if len(steps_flat) < 2: if len(steps_flat) < 2:
msg = f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}" msg = f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
raise ValueError(msg) raise ValueError(msg)
super().__init__( # type: ignore[call-arg] super().__init__(
first=steps_flat[0], first=steps_flat[0],
middle=list(steps_flat[1:-1]), middle=list(steps_flat[1:-1]),
last=steps_flat[-1], last=steps_flat[-1],
@ -3612,7 +3612,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
""" """
merged = {**steps__} if steps__ is not None else {} merged = {**steps__} if steps__ is not None else {}
merged.update(kwargs) merged.update(kwargs)
super().__init__( # type: ignore[call-arg] super().__init__(
steps__={key: coerce_to_runnable(r) for key, r in merged.items()} steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
) )
@ -5325,7 +5325,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
) )
class RunnableBindingBase(RunnableSerializable[Input, Output]): class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[no-redef]
"""``Runnable`` that delegates calls to another ``Runnable`` with a set of kwargs. """``Runnable`` that delegates calls to another ``Runnable`` with a set of kwargs.
Use only if creating a new ``RunnableBinding`` subclass with different ``__init__`` Use only if creating a new ``RunnableBinding`` subclass with different ``__init__``
@ -5404,7 +5404,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
``Runnable`` with a custom type. Defaults to None. ``Runnable`` with a custom type. Defaults to None.
**other_kwargs: Unpacked into the base class. **other_kwargs: Unpacked into the base class.
""" # noqa: E501 """ # noqa: E501
super().__init__( # type: ignore[call-arg] super().__init__(
bound=bound, bound=bound,
kwargs=kwargs or {}, kwargs=kwargs or {},
config=config or {}, config=config or {},
@ -5729,7 +5729,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item yield item
class RunnableBinding(RunnableBindingBase[Input, Output]): class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
"""Wrap a ``Runnable`` with additional functionality. """Wrap a ``Runnable`` with additional functionality.
A ``RunnableBinding`` can be thought of as a "runnable decorator" that A ``RunnableBinding`` can be thought of as a "runnable decorator" that

View File

@ -136,7 +136,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
super().__init__( super().__init__(
branches=branches_, branches=branches_,
default=default_, default=default_,
) # type: ignore[call-arg] )
model_config = ConfigDict( model_config = ConfigDict(
arbitrary_types_allowed=True, arbitrary_types_allowed=True,

View File

@ -38,7 +38,7 @@ MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory] GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase): class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
"""Runnable that manages chat message history for another Runnable. """Runnable that manages chat message history for another Runnable.
A chat message history is a sequence of messages that represent a conversation. A chat message history is a sequence of messages that represent a conversation.

View File

@ -186,7 +186,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
afunc = func afunc = func
func = None func = None
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg] super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
@classmethod @classmethod
@override @override
@ -406,7 +406,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
mapper: A ``RunnableParallel`` instance that will be used to transform the mapper: A ``RunnableParallel`` instance that will be used to transform the
input dictionary. input dictionary.
""" """
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg] super().__init__(mapper=mapper, **kwargs)
@classmethod @classmethod
@override @override
@ -710,7 +710,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
Args: Args:
keys: A single key or a list of keys to pick from the input dictionary. keys: A single key or a list of keys to pick from the input dictionary.
""" """
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg] super().__init__(keys=keys, **kwargs)
@classmethod @classmethod
@override @override

View File

@ -47,7 +47,7 @@ class ExponentialJitterParams(TypedDict, total=False):
"""Random additional wait sampled from random.uniform(0, jitter).""" """Random additional wait sampled from random.uniform(0, jitter)."""
class RunnableRetry(RunnableBindingBase[Input, Output]): class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
"""Retry a Runnable if it fails. """Retry a Runnable if it fails.
RunnableRetry can be used to add retry logic to any object RunnableRetry can be used to add retry logic to any object

View File

@ -87,7 +87,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
Args: Args:
runnables: A mapping of keys to Runnables. runnables: A mapping of keys to Runnables.
""" """
super().__init__( # type: ignore[call-arg] super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()} runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
) )

View File

@ -143,7 +143,7 @@ class Comparison(FilterDirective):
value: The value to compare to. value: The value to compare to.
""" """
# super exists from BaseModel # super exists from BaseModel
super().__init__( # type: ignore[call-arg] super().__init__(
comparator=comparator, attribute=attribute, value=value, **kwargs comparator=comparator, attribute=attribute, value=value, **kwargs
) )
@ -166,9 +166,7 @@ class Operation(FilterDirective):
arguments: The arguments to the operator. arguments: The arguments to the operator.
""" """
# super exists from BaseModel # super exists from BaseModel
super().__init__( # type: ignore[call-arg] super().__init__(operator=operator, arguments=arguments, **kwargs)
operator=operator, arguments=arguments, **kwargs
)
class StructuredQuery(Expr): class StructuredQuery(Expr):
@ -196,6 +194,4 @@ class StructuredQuery(Expr):
limit: The limit on the number of results. limit: The limit on the number of results.
""" """
# super exists from BaseModel # super exists from BaseModel
super().__init__( # type: ignore[call-arg] super().__init__(query=query, filter=filter, limit=limit, **kwargs)
query=query, filter=filter, limit=limit, **kwargs
)

View File

@ -228,7 +228,7 @@ class StructuredTool(BaseTool):
name=name, name=name,
func=func, func=func,
coroutine=coroutine, coroutine=coroutine,
args_schema=args_schema, # type: ignore[arg-type] args_schema=args_schema,
description=description_, description=description_,
return_direct=return_direct, return_direct=return_direct,
response_format=response_format, response_format=response_format,

View File

@ -64,6 +64,7 @@ langchain-text-splitters = { path = "../text-splitters" }
[tool.mypy] [tool.mypy]
plugins = ["pydantic.mypy"]
strict = "True" strict = "True"
strict_bytes = "True" strict_bytes = "True"
enable_error_code = "deprecated" enable_error_code = "deprecated"

View File

@ -810,7 +810,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
# Can't get pydantic to work here due to the odd typing of tryig to support # Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase. # both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item] parser = PydanticToolsParser(tools=[Forecast])
message = AIMessage( message = AIMessage(
content="", content="",
tool_calls=[ tool_calls=[

View File

@ -13,7 +13,7 @@ from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers import PydanticOutputParser from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.output_parsers.json import JsonOutputParser from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import TBaseModel from langchain_core.utils.pydantic import PydanticBaseModel, TBaseModel
class ForecastV2(pydantic.BaseModel): class ForecastV2(pydantic.BaseModel):
@ -43,7 +43,7 @@ def test_pydantic_parser_chaining(
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var] parser = PydanticOutputParser[PydanticBaseModel](pydantic_object=pydantic_object)
chain = prompt | model | parser chain = prompt | model | parser
res = chain.invoke({}) res = chain.invoke({})
@ -66,7 +66,9 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated] parser: PydanticOutputParser[PydanticBaseModel] = PydanticOutputParser(
pydantic_object=pydantic_object
)
chain = bad_prompt | model | parser chain = bad_prompt | model | parser
with pytest.raises(OutputParserException): with pytest.raises(OutputParserException):
chain.invoke({}) chain.invoke({})
@ -88,7 +90,7 @@ def test_json_parser_chaining(
model = ParrotFakeChatModel() model = ParrotFakeChatModel()
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type] parser = JsonOutputParser(pydantic_object=pydantic_object)
chain = prompt | model | parser chain = prompt | model | parser
res = chain.invoke({}) res = chain.invoke({})
@ -171,7 +173,7 @@ def test_pydantic_output_parser_type_inference() -> None:
# Ignoring mypy error that appears in python 3.8, but not 3.11. # Ignoring mypy error that appears in python 3.8, but not 3.11.
# This seems to be functionally correct, so we'll ignore the error. # This seems to be functionally correct, so we'll ignore the error.
pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel) pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel)
schema = pydantic_parser.get_output_schema().model_json_schema() schema = pydantic_parser.get_output_schema().model_json_schema()
assert schema == { assert schema == {
@ -202,5 +204,5 @@ def test_format_instructions_preserves_language() -> None:
) )
) )
parser = PydanticOutputParser(pydantic_object=Foo) parser = PydanticOutputParser[Foo](pydantic_object=Foo)
assert description in parser.get_format_instructions() assert description in parser.get_format_instructions()

View File

@ -348,7 +348,7 @@ def test_prompt_invalid_template_format() -> None:
PromptTemplate( PromptTemplate(
input_variables=input_variables, input_variables=input_variables,
template=template, template=template,
template_format="bar", # type: ignore[arg-type] template_format="bar",
) )

View File

@ -73,7 +73,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def test_doubly_set_configurable() -> None: def test_doubly_set_configurable() -> None:
"""Test that setting a configurable field with a default value works.""" """Test that setting a configurable field with a default value works."""
runnable = MyRunnable(my_property="a") # type: ignore[call-arg] runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(
id="my_property", id="my_property",
@ -86,7 +86,7 @@ def test_doubly_set_configurable() -> None:
def test_alias_set_configurable() -> None: def test_alias_set_configurable() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg] runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(
id="my_property_alias", id="my_property_alias",
@ -104,7 +104,7 @@ def test_alias_set_configurable() -> None:
def test_field_alias_set_configurable() -> None: def test_field_alias_set_configurable() -> None:
runnable = MyRunnable(my_property_alias="a") runnable = MyRunnable(my_property_alias="a") # type: ignore[call-arg]
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(
id="my_property", id="my_property",
@ -122,7 +122,7 @@ def test_field_alias_set_configurable() -> None:
def test_config_passthrough() -> None: def test_config_passthrough() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg] runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(
id="my_property", id="my_property",
@ -158,7 +158,7 @@ def test_config_passthrough() -> None:
def test_config_passthrough_nested() -> None: def test_config_passthrough_nested() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg] runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields( configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField( my_property=ConfigurableField(
id="my_property", id="my_property",

View File

@ -347,7 +347,7 @@ def test_using_secret_from_env_as_default_factory(
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY")) secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
# Pass the secret as a parameter # Pass the secret as a parameter
foo = Foo(secret="super_secret") # type: ignore[arg-type] foo = Foo(secret="super_secret")
assert foo.secret.get_secret_value() == "super_secret" assert foo.secret.get_secret_value() == "super_secret"
# Set the environment variable # Set the environment variable