chore(core): add mypy pydantic plugin (#32604)

This helps to remove a bunch of mypy false positives.
This commit is contained in:
Christophe Bornet 2025-08-19 15:39:53 +02:00 committed by GitHub
parent b470c79f1d
commit 02d6b9106b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 45 additions and 48 deletions

View File

@ -253,7 +253,7 @@ class ContextSet(RunnableSerializable):
"""
if key is not None:
kwargs[key] = value
super().__init__( # type: ignore[call-arg]
super().__init__(
keys={
k: _coerce_set_value(v) if v is not None else None
for k, v in kwargs.items()

View File

@ -277,7 +277,7 @@ class Document(BaseMedia):
"""Pass page_content in as positional or named arg."""
# my-py is complaining that page_content is not defined on the base class.
# Here, we're relying on pydantic base class to handle the validation.
super().__init__(page_content=page_content, **kwargs) # type: ignore[call-arg]
super().__init__(page_content=page_content, **kwargs)
@classmethod
def is_lc_serializable(cls) -> bool:

View File

@ -533,7 +533,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generations = [generations_with_error_metadata]
run_manager.on_llm_error(
e,
response=LLMResult(generations=generations), # type: ignore[arg-type]
response=LLMResult(generations=generations),
)
raise
@ -627,7 +627,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
generations = [generations_with_error_metadata]
await run_manager.on_llm_error(
e,
response=LLMResult(generations=generations), # type: ignore[arg-type]
response=LLMResult(generations=generations),
)
raise
@ -842,17 +842,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
run_managers[i].on_llm_error(
e,
response=LLMResult(
generations=[generations_with_error_metadata] # type: ignore[list-item]
generations=[generations_with_error_metadata]
),
)
raise
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item]
LLMResult(generations=[res.generations], llm_output=res.llm_output)
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
output = LLMResult(generations=generations, llm_output=llm_output)
if run_managers:
run_infos = []
for manager, flattened_output in zip(run_managers, flattened_outputs):
@ -962,7 +962,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
await run_managers[i].on_llm_error(
res,
response=LLMResult(
generations=[generations_with_error_metadata] # type: ignore[list-item]
generations=[generations_with_error_metadata]
),
)
exceptions.append(res)
@ -972,7 +972,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
*[
run_manager.on_llm_end(
LLMResult(
generations=[res.generations], # type: ignore[list-item, union-attr]
generations=[res.generations], # type: ignore[union-attr]
llm_output=res.llm_output, # type: ignore[union-attr]
)
)
@ -982,12 +982,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
)
raise exceptions[0]
flattened_outputs = [
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item, union-attr]
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[union-attr]
for res in results
]
llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr]
generations = [res.generations for res in results] # type: ignore[union-attr]
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
output = LLMResult(generations=generations, llm_output=llm_output)
await asyncio.gather(
*[
run_manager.on_llm_end(flattened_output)

View File

@ -155,9 +155,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
"""
# mypy can't detect the init which is defined in the parent class
# b/c these are BaseModel classes.
super().__init__( # type: ignore[call-arg]
variable_name=variable_name, optional=optional, **kwargs
)
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
"""Format messages from kwargs.

View File

@ -2819,7 +2819,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
if len(steps_flat) < 2:
msg = f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
raise ValueError(msg)
super().__init__( # type: ignore[call-arg]
super().__init__(
first=steps_flat[0],
middle=list(steps_flat[1:-1]),
last=steps_flat[-1],
@ -3612,7 +3612,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
"""
merged = {**steps__} if steps__ is not None else {}
merged.update(kwargs)
super().__init__( # type: ignore[call-arg]
super().__init__(
steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
)
@ -5325,7 +5325,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
)
class RunnableBindingBase(RunnableSerializable[Input, Output]):
class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[no-redef]
"""``Runnable`` that delegates calls to another ``Runnable`` with a set of kwargs.
Use only if creating a new ``RunnableBinding`` subclass with different ``__init__``
@ -5404,7 +5404,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
``Runnable`` with a custom type. Defaults to None.
**other_kwargs: Unpacked into the base class.
""" # noqa: E501
super().__init__( # type: ignore[call-arg]
super().__init__(
bound=bound,
kwargs=kwargs or {},
config=config or {},
@ -5729,7 +5729,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
yield item
class RunnableBinding(RunnableBindingBase[Input, Output]):
class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
"""Wrap a ``Runnable`` with additional functionality.
A ``RunnableBinding`` can be thought of as a "runnable decorator" that

View File

@ -136,7 +136,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
super().__init__(
branches=branches_,
default=default_,
) # type: ignore[call-arg]
)
model_config = ConfigDict(
arbitrary_types_allowed=True,

View File

@ -38,7 +38,7 @@ MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
class RunnableWithMessageHistory(RunnableBindingBase):
class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
"""Runnable that manages chat message history for another Runnable.
A chat message history is a sequence of messages that represent a conversation.

View File

@ -186,7 +186,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
afunc = func
func = None
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg]
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
@classmethod
@override
@ -406,7 +406,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
mapper: A ``RunnableParallel`` instance that will be used to transform the
input dictionary.
"""
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
super().__init__(mapper=mapper, **kwargs)
@classmethod
@override
@ -710,7 +710,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
Args:
keys: A single key or a list of keys to pick from the input dictionary.
"""
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
super().__init__(keys=keys, **kwargs)
@classmethod
@override

View File

@ -47,7 +47,7 @@ class ExponentialJitterParams(TypedDict, total=False):
"""Random additional wait sampled from random.uniform(0, jitter)."""
class RunnableRetry(RunnableBindingBase[Input, Output]):
class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
"""Retry a Runnable if it fails.
RunnableRetry can be used to add retry logic to any object

View File

@ -87,7 +87,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
Args:
runnables: A mapping of keys to Runnables.
"""
super().__init__( # type: ignore[call-arg]
super().__init__(
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
)

View File

@ -143,7 +143,7 @@ class Comparison(FilterDirective):
value: The value to compare to.
"""
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
super().__init__(
comparator=comparator, attribute=attribute, value=value, **kwargs
)
@ -166,9 +166,7 @@ class Operation(FilterDirective):
arguments: The arguments to the operator.
"""
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
operator=operator, arguments=arguments, **kwargs
)
super().__init__(operator=operator, arguments=arguments, **kwargs)
class StructuredQuery(Expr):
@ -196,6 +194,4 @@ class StructuredQuery(Expr):
limit: The limit on the number of results.
"""
# super exists from BaseModel
super().__init__( # type: ignore[call-arg]
query=query, filter=filter, limit=limit, **kwargs
)
super().__init__(query=query, filter=filter, limit=limit, **kwargs)

View File

@ -228,7 +228,7 @@ class StructuredTool(BaseTool):
name=name,
func=func,
coroutine=coroutine,
args_schema=args_schema, # type: ignore[arg-type]
args_schema=args_schema,
description=description_,
return_direct=return_direct,
response_format=response_format,

View File

@ -64,6 +64,7 @@ langchain-text-splitters = { path = "../text-splitters" }
[tool.mypy]
plugins = ["pydantic.mypy"]
strict = "True"
strict_bytes = "True"
enable_error_code = "deprecated"

View File

@ -810,7 +810,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
parser = PydanticToolsParser(tools=[Forecast])
message = AIMessage(
content="",
tool_calls=[

View File

@ -13,7 +13,7 @@ from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import TBaseModel
from langchain_core.utils.pydantic import PydanticBaseModel, TBaseModel
class ForecastV2(pydantic.BaseModel):
@ -43,7 +43,7 @@ def test_pydantic_parser_chaining(
model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
parser = PydanticOutputParser[PydanticBaseModel](pydantic_object=pydantic_object)
chain = prompt | model | parser
res = chain.invoke({})
@ -66,7 +66,9 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
model = ParrotFakeChatModel()
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
parser: PydanticOutputParser[PydanticBaseModel] = PydanticOutputParser(
pydantic_object=pydantic_object
)
chain = bad_prompt | model | parser
with pytest.raises(OutputParserException):
chain.invoke({})
@ -88,7 +90,7 @@ def test_json_parser_chaining(
model = ParrotFakeChatModel()
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type]
parser = JsonOutputParser(pydantic_object=pydantic_object)
chain = prompt | model | parser
res = chain.invoke({})
@ -171,7 +173,7 @@ def test_pydantic_output_parser_type_inference() -> None:
# Ignoring mypy error that appears in python 3.8, but not 3.11.
# This seems to be functionally correct, so we'll ignore the error.
pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel)
pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel)
schema = pydantic_parser.get_output_schema().model_json_schema()
assert schema == {
@ -202,5 +204,5 @@ def test_format_instructions_preserves_language() -> None:
)
)
parser = PydanticOutputParser(pydantic_object=Foo)
parser = PydanticOutputParser[Foo](pydantic_object=Foo)
assert description in parser.get_format_instructions()

View File

@ -348,7 +348,7 @@ def test_prompt_invalid_template_format() -> None:
PromptTemplate(
input_variables=input_variables,
template=template,
template_format="bar", # type: ignore[arg-type]
template_format="bar",
)

View File

@ -73,7 +73,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
def test_doubly_set_configurable() -> None:
"""Test that setting a configurable field with a default value works."""
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",
@ -86,7 +86,7 @@ def test_doubly_set_configurable() -> None:
def test_alias_set_configurable() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property_alias",
@ -104,7 +104,7 @@ def test_alias_set_configurable() -> None:
def test_field_alias_set_configurable() -> None:
runnable = MyRunnable(my_property_alias="a")
runnable = MyRunnable(my_property_alias="a") # type: ignore[call-arg]
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",
@ -122,7 +122,7 @@ def test_field_alias_set_configurable() -> None:
def test_config_passthrough() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",
@ -158,7 +158,7 @@ def test_config_passthrough() -> None:
def test_config_passthrough_nested() -> None:
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
runnable = MyRunnable(my_property="a")
configurable_runnable = runnable.configurable_fields(
my_property=ConfigurableField(
id="my_property",

View File

@ -347,7 +347,7 @@ def test_using_secret_from_env_as_default_factory(
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
# Pass the secret as a parameter
foo = Foo(secret="super_secret") # type: ignore[arg-type]
foo = Foo(secret="super_secret")
assert foo.secret.get_secret_value() == "super_secret"
# Set the environment variable