mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-22 02:45:49 +00:00
chore(core): add mypy pydantic plugin (#32604)
This helps to remove a bunch of mypy false positives.
This commit is contained in:
parent
b470c79f1d
commit
02d6b9106b
@ -253,7 +253,7 @@ class ContextSet(RunnableSerializable):
|
||||
"""
|
||||
if key is not None:
|
||||
kwargs[key] = value
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
keys={
|
||||
k: _coerce_set_value(v) if v is not None else None
|
||||
for k, v in kwargs.items()
|
||||
|
@ -277,7 +277,7 @@ class Document(BaseMedia):
|
||||
"""Pass page_content in as positional or named arg."""
|
||||
# my-py is complaining that page_content is not defined on the base class.
|
||||
# Here, we're relying on pydantic base class to handle the validation.
|
||||
super().__init__(page_content=page_content, **kwargs) # type: ignore[call-arg]
|
||||
super().__init__(page_content=page_content, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
|
@ -533,7 +533,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
generations = [generations_with_error_metadata]
|
||||
run_manager.on_llm_error(
|
||||
e,
|
||||
response=LLMResult(generations=generations), # type: ignore[arg-type]
|
||||
response=LLMResult(generations=generations),
|
||||
)
|
||||
raise
|
||||
|
||||
@ -627,7 +627,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
generations = [generations_with_error_metadata]
|
||||
await run_manager.on_llm_error(
|
||||
e,
|
||||
response=LLMResult(generations=generations), # type: ignore[arg-type]
|
||||
response=LLMResult(generations=generations),
|
||||
)
|
||||
raise
|
||||
|
||||
@ -842,17 +842,17 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
run_managers[i].on_llm_error(
|
||||
e,
|
||||
response=LLMResult(
|
||||
generations=[generations_with_error_metadata] # type: ignore[list-item]
|
||||
generations=[generations_with_error_metadata]
|
||||
),
|
||||
)
|
||||
raise
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item]
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output)
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results])
|
||||
generations = [res.generations for res in results]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
if run_managers:
|
||||
run_infos = []
|
||||
for manager, flattened_output in zip(run_managers, flattened_outputs):
|
||||
@ -962,7 +962,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
await run_managers[i].on_llm_error(
|
||||
res,
|
||||
response=LLMResult(
|
||||
generations=[generations_with_error_metadata] # type: ignore[list-item]
|
||||
generations=[generations_with_error_metadata]
|
||||
),
|
||||
)
|
||||
exceptions.append(res)
|
||||
@ -972,7 +972,7 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
*[
|
||||
run_manager.on_llm_end(
|
||||
LLMResult(
|
||||
generations=[res.generations], # type: ignore[list-item, union-attr]
|
||||
generations=[res.generations], # type: ignore[union-attr]
|
||||
llm_output=res.llm_output, # type: ignore[union-attr]
|
||||
)
|
||||
)
|
||||
@ -982,12 +982,12 @@ class BaseChatModel(BaseLanguageModel[BaseMessage], ABC):
|
||||
)
|
||||
raise exceptions[0]
|
||||
flattened_outputs = [
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[list-item, union-attr]
|
||||
LLMResult(generations=[res.generations], llm_output=res.llm_output) # type: ignore[union-attr]
|
||||
for res in results
|
||||
]
|
||||
llm_output = self._combine_llm_outputs([res.llm_output for res in results]) # type: ignore[union-attr]
|
||||
generations = [res.generations for res in results] # type: ignore[union-attr]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output) # type: ignore[arg-type]
|
||||
output = LLMResult(generations=generations, llm_output=llm_output)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
run_manager.on_llm_end(flattened_output)
|
||||
|
@ -155,9 +155,7 @@ class MessagesPlaceholder(BaseMessagePromptTemplate):
|
||||
"""
|
||||
# mypy can't detect the init which is defined in the parent class
|
||||
# b/c these are BaseModel classes.
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
variable_name=variable_name, optional=optional, **kwargs
|
||||
)
|
||||
super().__init__(variable_name=variable_name, optional=optional, **kwargs)
|
||||
|
||||
def format_messages(self, **kwargs: Any) -> list[BaseMessage]:
|
||||
"""Format messages from kwargs.
|
||||
|
@ -2819,7 +2819,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
|
||||
if len(steps_flat) < 2:
|
||||
msg = f"RunnableSequence must have at least 2 steps, got {len(steps_flat)}"
|
||||
raise ValueError(msg)
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
first=steps_flat[0],
|
||||
middle=list(steps_flat[1:-1]),
|
||||
last=steps_flat[-1],
|
||||
@ -3612,7 +3612,7 @@ class RunnableParallel(RunnableSerializable[Input, dict[str, Any]]):
|
||||
"""
|
||||
merged = {**steps__} if steps__ is not None else {}
|
||||
merged.update(kwargs)
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
steps__={key: coerce_to_runnable(r) for key, r in merged.items()}
|
||||
)
|
||||
|
||||
@ -5325,7 +5325,7 @@ class RunnableEach(RunnableEachBase[Input, Output]):
|
||||
)
|
||||
|
||||
|
||||
class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
class RunnableBindingBase(RunnableSerializable[Input, Output]): # type: ignore[no-redef]
|
||||
"""``Runnable`` that delegates calls to another ``Runnable`` with a set of kwargs.
|
||||
|
||||
Use only if creating a new ``RunnableBinding`` subclass with different ``__init__``
|
||||
@ -5404,7 +5404,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
``Runnable`` with a custom type. Defaults to None.
|
||||
**other_kwargs: Unpacked into the base class.
|
||||
""" # noqa: E501
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
bound=bound,
|
||||
kwargs=kwargs or {},
|
||||
config=config or {},
|
||||
@ -5729,7 +5729,7 @@ class RunnableBindingBase(RunnableSerializable[Input, Output]):
|
||||
yield item
|
||||
|
||||
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]):
|
||||
class RunnableBinding(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
|
||||
"""Wrap a ``Runnable`` with additional functionality.
|
||||
|
||||
A ``RunnableBinding`` can be thought of as a "runnable decorator" that
|
||||
|
@ -136,7 +136,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]):
|
||||
super().__init__(
|
||||
branches=branches_,
|
||||
default=default_,
|
||||
) # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
arbitrary_types_allowed=True,
|
||||
|
@ -38,7 +38,7 @@ MessagesOrDictWithMessages = Union[Sequence["BaseMessage"], dict[str, Any]]
|
||||
GetSessionHistoryCallable = Callable[..., BaseChatMessageHistory]
|
||||
|
||||
|
||||
class RunnableWithMessageHistory(RunnableBindingBase):
|
||||
class RunnableWithMessageHistory(RunnableBindingBase): # type: ignore[no-redef]
|
||||
"""Runnable that manages chat message history for another Runnable.
|
||||
|
||||
A chat message history is a sequence of messages that represent a conversation.
|
||||
|
@ -186,7 +186,7 @@ class RunnablePassthrough(RunnableSerializable[Other, Other]):
|
||||
afunc = func
|
||||
func = None
|
||||
|
||||
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs) # type: ignore[call-arg]
|
||||
super().__init__(func=func, afunc=afunc, input_type=input_type, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -406,7 +406,7 @@ class RunnableAssign(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
mapper: A ``RunnableParallel`` instance that will be used to transform the
|
||||
input dictionary.
|
||||
"""
|
||||
super().__init__(mapper=mapper, **kwargs) # type: ignore[call-arg]
|
||||
super().__init__(mapper=mapper, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
@ -710,7 +710,7 @@ class RunnablePick(RunnableSerializable[dict[str, Any], dict[str, Any]]):
|
||||
Args:
|
||||
keys: A single key or a list of keys to pick from the input dictionary.
|
||||
"""
|
||||
super().__init__(keys=keys, **kwargs) # type: ignore[call-arg]
|
||||
super().__init__(keys=keys, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
|
@ -47,7 +47,7 @@ class ExponentialJitterParams(TypedDict, total=False):
|
||||
"""Random additional wait sampled from random.uniform(0, jitter)."""
|
||||
|
||||
|
||||
class RunnableRetry(RunnableBindingBase[Input, Output]):
|
||||
class RunnableRetry(RunnableBindingBase[Input, Output]): # type: ignore[no-redef]
|
||||
"""Retry a Runnable if it fails.
|
||||
|
||||
RunnableRetry can be used to add retry logic to any object
|
||||
|
@ -87,7 +87,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]):
|
||||
Args:
|
||||
runnables: A mapping of keys to Runnables.
|
||||
"""
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
runnables={key: coerce_to_runnable(r) for key, r in runnables.items()}
|
||||
)
|
||||
|
||||
|
@ -143,7 +143,7 @@ class Comparison(FilterDirective):
|
||||
value: The value to compare to.
|
||||
"""
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
super().__init__(
|
||||
comparator=comparator, attribute=attribute, value=value, **kwargs
|
||||
)
|
||||
|
||||
@ -166,9 +166,7 @@ class Operation(FilterDirective):
|
||||
arguments: The arguments to the operator.
|
||||
"""
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
operator=operator, arguments=arguments, **kwargs
|
||||
)
|
||||
super().__init__(operator=operator, arguments=arguments, **kwargs)
|
||||
|
||||
|
||||
class StructuredQuery(Expr):
|
||||
@ -196,6 +194,4 @@ class StructuredQuery(Expr):
|
||||
limit: The limit on the number of results.
|
||||
"""
|
||||
# super exists from BaseModel
|
||||
super().__init__( # type: ignore[call-arg]
|
||||
query=query, filter=filter, limit=limit, **kwargs
|
||||
)
|
||||
super().__init__(query=query, filter=filter, limit=limit, **kwargs)
|
||||
|
@ -228,7 +228,7 @@ class StructuredTool(BaseTool):
|
||||
name=name,
|
||||
func=func,
|
||||
coroutine=coroutine,
|
||||
args_schema=args_schema, # type: ignore[arg-type]
|
||||
args_schema=args_schema,
|
||||
description=description_,
|
||||
return_direct=return_direct,
|
||||
response_format=response_format,
|
||||
|
@ -64,6 +64,7 @@ langchain-text-splitters = { path = "../text-splitters" }
|
||||
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ["pydantic.mypy"]
|
||||
strict = "True"
|
||||
strict_bytes = "True"
|
||||
enable_error_code = "deprecated"
|
||||
|
@ -810,7 +810,7 @@ def test_parse_with_different_pydantic_2_v1() -> None:
|
||||
|
||||
# Can't get pydantic to work here due to the odd typing of tryig to support
|
||||
# both v1 and v2 in the same codebase.
|
||||
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
|
||||
parser = PydanticToolsParser(tools=[Forecast])
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
|
@ -13,7 +13,7 @@ from langchain_core.language_models import ParrotFakeChatModel
|
||||
from langchain_core.output_parsers import PydanticOutputParser
|
||||
from langchain_core.output_parsers.json import JsonOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.utils.pydantic import TBaseModel
|
||||
from langchain_core.utils.pydantic import PydanticBaseModel, TBaseModel
|
||||
|
||||
|
||||
class ForecastV2(pydantic.BaseModel):
|
||||
@ -43,7 +43,7 @@ def test_pydantic_parser_chaining(
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[type-var]
|
||||
parser = PydanticOutputParser[PydanticBaseModel](pydantic_object=pydantic_object)
|
||||
chain = prompt | model | parser
|
||||
|
||||
res = chain.invoke({})
|
||||
@ -66,7 +66,9 @@ def test_pydantic_parser_validation(pydantic_object: TBaseModel) -> None:
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type,var-annotated]
|
||||
parser: PydanticOutputParser[PydanticBaseModel] = PydanticOutputParser(
|
||||
pydantic_object=pydantic_object
|
||||
)
|
||||
chain = bad_prompt | model | parser
|
||||
with pytest.raises(OutputParserException):
|
||||
chain.invoke({})
|
||||
@ -88,7 +90,7 @@ def test_json_parser_chaining(
|
||||
|
||||
model = ParrotFakeChatModel()
|
||||
|
||||
parser = JsonOutputParser(pydantic_object=pydantic_object) # type: ignore[arg-type]
|
||||
parser = JsonOutputParser(pydantic_object=pydantic_object)
|
||||
chain = prompt | model | parser
|
||||
|
||||
res = chain.invoke({})
|
||||
@ -171,7 +173,7 @@ def test_pydantic_output_parser_type_inference() -> None:
|
||||
|
||||
# Ignoring mypy error that appears in python 3.8, but not 3.11.
|
||||
# This seems to be functionally correct, so we'll ignore the error.
|
||||
pydantic_parser = PydanticOutputParser(pydantic_object=SampleModel)
|
||||
pydantic_parser = PydanticOutputParser[SampleModel](pydantic_object=SampleModel)
|
||||
schema = pydantic_parser.get_output_schema().model_json_schema()
|
||||
|
||||
assert schema == {
|
||||
@ -202,5 +204,5 @@ def test_format_instructions_preserves_language() -> None:
|
||||
)
|
||||
)
|
||||
|
||||
parser = PydanticOutputParser(pydantic_object=Foo)
|
||||
parser = PydanticOutputParser[Foo](pydantic_object=Foo)
|
||||
assert description in parser.get_format_instructions()
|
||||
|
@ -348,7 +348,7 @@ def test_prompt_invalid_template_format() -> None:
|
||||
PromptTemplate(
|
||||
input_variables=input_variables,
|
||||
template=template,
|
||||
template_format="bar", # type: ignore[arg-type]
|
||||
template_format="bar",
|
||||
)
|
||||
|
||||
|
||||
|
@ -73,7 +73,7 @@ class MyOtherRunnable(RunnableSerializable[str, str]):
|
||||
|
||||
def test_doubly_set_configurable() -> None:
|
||||
"""Test that setting a configurable field with a default value works."""
|
||||
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
|
||||
runnable = MyRunnable(my_property="a")
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property",
|
||||
@ -86,7 +86,7 @@ def test_doubly_set_configurable() -> None:
|
||||
|
||||
|
||||
def test_alias_set_configurable() -> None:
|
||||
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
|
||||
runnable = MyRunnable(my_property="a")
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property_alias",
|
||||
@ -104,7 +104,7 @@ def test_alias_set_configurable() -> None:
|
||||
|
||||
|
||||
def test_field_alias_set_configurable() -> None:
|
||||
runnable = MyRunnable(my_property_alias="a")
|
||||
runnable = MyRunnable(my_property_alias="a") # type: ignore[call-arg]
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property",
|
||||
@ -122,7 +122,7 @@ def test_field_alias_set_configurable() -> None:
|
||||
|
||||
|
||||
def test_config_passthrough() -> None:
|
||||
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
|
||||
runnable = MyRunnable(my_property="a")
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property",
|
||||
@ -158,7 +158,7 @@ def test_config_passthrough() -> None:
|
||||
|
||||
|
||||
def test_config_passthrough_nested() -> None:
|
||||
runnable = MyRunnable(my_property="a") # type: ignore[call-arg]
|
||||
runnable = MyRunnable(my_property="a")
|
||||
configurable_runnable = runnable.configurable_fields(
|
||||
my_property=ConfigurableField(
|
||||
id="my_property",
|
||||
|
@ -347,7 +347,7 @@ def test_using_secret_from_env_as_default_factory(
|
||||
secret: SecretStr = Field(default_factory=secret_from_env("TEST_KEY"))
|
||||
|
||||
# Pass the secret as a parameter
|
||||
foo = Foo(secret="super_secret") # type: ignore[arg-type]
|
||||
foo = Foo(secret="super_secret")
|
||||
assert foo.secret.get_secret_value() == "super_secret"
|
||||
|
||||
# Set the environment variable
|
||||
|
Loading…
Reference in New Issue
Block a user