style(core): fix mypy no-any-return violations (#34204)

* FIxed where possible
* Used `cast` when not possible to fix

---------

Co-authored-by: Mason Daugherty <github@mdrxy.com>
Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet
2025-12-27 04:35:27 +01:00
committed by GitHub
parent 88b5f22f1c
commit a92c032ff6
42 changed files with 226 additions and 143 deletions

View File

@@ -9,12 +9,7 @@ from abc import ABC, abstractmethod
from collections.abc import Mapping # noqa: TC003
from functools import cached_property
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Generic,
TypeVar,
)
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
import yaml
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -122,7 +117,10 @@ class BasePromptTemplate(
@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumpd uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))
@property
@override
@@ -156,7 +154,7 @@ class BasePromptTemplate(
if not isinstance(inner_input, dict):
if len(self.input_variables) == 1:
var_name = self.input_variables[0]
inner_input = {var_name: inner_input}
inner_input_ = {var_name: inner_input}
else:
msg = (
@@ -168,12 +166,14 @@ class BasePromptTemplate(
message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT
)
)
missing = set(self.input_variables).difference(inner_input)
else:
inner_input_ = inner_input
missing = set(self.input_variables).difference(inner_input_)
if missing:
msg = (
f"Input to {self.__class__.__name__} is missing variables {missing}. "
f" Expected: {self.input_variables}"
f" Received: {list(inner_input.keys())}"
f" Received: {list(inner_input_.keys())}"
)
example_key = missing.pop()
msg += (
@@ -184,7 +184,7 @@ class BasePromptTemplate(
raise KeyError(
create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
)
return inner_input
return inner_input_
def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
inner_input_ = self._validate_input(inner_input)

View File

@@ -2,7 +2,7 @@
import warnings
from functools import cached_property
from typing import Any, Literal
from typing import Any, Literal, cast
from typing_extensions import override
@@ -65,7 +65,10 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]):
@cached_property
def _serialized(self) -> dict[str, Any]:
return dumpd(self)
# self is always a Serializable object in this case, thus the result is
# guaranteed to be a dict since dumpd uses the default callback, which uses
# obj.to_json which always returns TypedDict subclasses
return cast("dict[str, Any]", dumpd(self))
@classmethod
def is_lc_serializable(cls) -> bool:
@@ -116,7 +119,7 @@ def _insert_input_variables(
inputs: dict[str, Any],
template_format: Literal["f-string", "mustache"],
) -> dict[str, Any]:
formatted = {}
formatted: dict[str, Any] = {}
formatter = DEFAULT_FORMATTER_MAPPING[template_format]
for k, v in template.items():
if isinstance(v, str):
@@ -132,7 +135,7 @@ def _insert_input_variables(
warnings.warn(msg, stacklevel=2)
formatted[k] = _insert_input_variables(v, inputs, template_format)
elif isinstance(v, (list, tuple)):
formatted_v = []
formatted_v: list[str | dict[str, Any]] = []
for x in v:
if isinstance(x, str):
formatted_v.append(formatter(x, **inputs))

View File

@@ -6,6 +6,7 @@ from typing import Any
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from langchain_core.example_selectors import BaseExampleSelector
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import (
DEFAULT_FORMATTER_MAPPING,
@@ -21,7 +22,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
"""Examples to format into the prompt.
Either this or example_selector should be provided."""
example_selector: Any = None
example_selector: BaseExampleSelector | None = None
"""ExampleSelector to choose the examples to format into the prompt.
Either this or examples should be provided."""

View File

@@ -1,6 +1,6 @@
"""Image prompt template for a multimodal model."""
from typing import Any
from typing import Any, Literal, cast
from pydantic import Field
@@ -125,7 +125,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
output: ImageURL = {"url": url}
if detail:
# Don't check literal values here: let the API check them
output["detail"] = detail
output["detail"] = cast("Literal['auto', 'low', 'high']", detail)
return output
async def aformat(self, **kwargs: Any) -> ImageURL:

View File

@@ -92,4 +92,4 @@ class BaseMessagePromptTemplate(Serializable, ABC):
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
prompt = ChatPromptTemplate(messages=[self])
return prompt + other
return prompt.__add__(other)

View File

@@ -3,11 +3,12 @@
from __future__ import annotations
import warnings
from abc import ABC
from abc import ABC, abstractmethod
from string import Formatter
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, create_model
from typing_extensions import override
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
@@ -189,17 +190,20 @@ def mustache_schema(template: str) -> type[BaseModel]:
return _create_model_recursive("PromptInput", defs)
def _create_model_recursive(name: str, defs: Defs) -> type:
return create_model( # type: ignore[call-overload]
name,
**{
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
for k, v in defs.items()
},
def _create_model_recursive(name: str, defs: Defs) -> type[BaseModel]:
return cast(
"type[BaseModel]",
create_model( # type: ignore[call-overload]
name,
**{
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
for k, v in defs.items()
},
),
)
DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
DEFAULT_FORMATTER_MAPPING: dict[str, Callable[..., str]] = {
"f-string": formatter.format,
"mustache": mustache_formatter,
"jinja2": jinja2_formatter,
@@ -330,6 +334,10 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
"""
return StringPromptValue(text=await self.aformat(**kwargs))
@override
@abstractmethod
def format(self, **kwargs: Any) -> str: ...
def pretty_repr(
self,
html: bool = False, # noqa: FBT001,FBT002