mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 10:17:00 +00:00
style(core): fix mypy no-any-return violations (#34204)
* FIxed where possible * Used `cast` when not possible to fix --------- Co-authored-by: Mason Daugherty <github@mdrxy.com> Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
committed by
GitHub
parent
88b5f22f1c
commit
a92c032ff6
@@ -9,12 +9,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import Mapping # noqa: TC003
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
@@ -122,7 +117,10 @@ class BasePromptTemplate(
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
return dumpd(self)
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
# guaranteed to be a dict since dumpd uses the default callback, which uses
|
||||
# obj.to_json which always returns TypedDict subclasses
|
||||
return cast("dict[str, Any]", dumpd(self))
|
||||
|
||||
@property
|
||||
@override
|
||||
@@ -156,7 +154,7 @@ class BasePromptTemplate(
|
||||
if not isinstance(inner_input, dict):
|
||||
if len(self.input_variables) == 1:
|
||||
var_name = self.input_variables[0]
|
||||
inner_input = {var_name: inner_input}
|
||||
inner_input_ = {var_name: inner_input}
|
||||
|
||||
else:
|
||||
msg = (
|
||||
@@ -168,12 +166,14 @@ class BasePromptTemplate(
|
||||
message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT
|
||||
)
|
||||
)
|
||||
missing = set(self.input_variables).difference(inner_input)
|
||||
else:
|
||||
inner_input_ = inner_input
|
||||
missing = set(self.input_variables).difference(inner_input_)
|
||||
if missing:
|
||||
msg = (
|
||||
f"Input to {self.__class__.__name__} is missing variables {missing}. "
|
||||
f" Expected: {self.input_variables}"
|
||||
f" Received: {list(inner_input.keys())}"
|
||||
f" Received: {list(inner_input_.keys())}"
|
||||
)
|
||||
example_key = missing.pop()
|
||||
msg += (
|
||||
@@ -184,7 +184,7 @@ class BasePromptTemplate(
|
||||
raise KeyError(
|
||||
create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT)
|
||||
)
|
||||
return inner_input
|
||||
return inner_input_
|
||||
|
||||
def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue:
|
||||
inner_input_ = self._validate_input(inner_input)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import warnings
|
||||
from functools import cached_property
|
||||
from typing import Any, Literal
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
@@ -65,7 +65,10 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]):
|
||||
|
||||
@cached_property
|
||||
def _serialized(self) -> dict[str, Any]:
|
||||
return dumpd(self)
|
||||
# self is always a Serializable object in this case, thus the result is
|
||||
# guaranteed to be a dict since dumpd uses the default callback, which uses
|
||||
# obj.to_json which always returns TypedDict subclasses
|
||||
return cast("dict[str, Any]", dumpd(self))
|
||||
|
||||
@classmethod
|
||||
def is_lc_serializable(cls) -> bool:
|
||||
@@ -116,7 +119,7 @@ def _insert_input_variables(
|
||||
inputs: dict[str, Any],
|
||||
template_format: Literal["f-string", "mustache"],
|
||||
) -> dict[str, Any]:
|
||||
formatted = {}
|
||||
formatted: dict[str, Any] = {}
|
||||
formatter = DEFAULT_FORMATTER_MAPPING[template_format]
|
||||
for k, v in template.items():
|
||||
if isinstance(v, str):
|
||||
@@ -132,7 +135,7 @@ def _insert_input_variables(
|
||||
warnings.warn(msg, stacklevel=2)
|
||||
formatted[k] = _insert_input_variables(v, inputs, template_format)
|
||||
elif isinstance(v, (list, tuple)):
|
||||
formatted_v = []
|
||||
formatted_v: list[str | dict[str, Any]] = []
|
||||
for x in v:
|
||||
if isinstance(x, str):
|
||||
formatted_v.append(formatter(x, **inputs))
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from langchain_core.example_selectors import BaseExampleSelector
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.prompts.string import (
|
||||
DEFAULT_FORMATTER_MAPPING,
|
||||
@@ -21,7 +22,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate):
|
||||
"""Examples to format into the prompt.
|
||||
Either this or example_selector should be provided."""
|
||||
|
||||
example_selector: Any = None
|
||||
example_selector: BaseExampleSelector | None = None
|
||||
"""ExampleSelector to choose the examples to format into the prompt.
|
||||
Either this or examples should be provided."""
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Image prompt template for a multimodal model."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -125,7 +125,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]):
|
||||
output: ImageURL = {"url": url}
|
||||
if detail:
|
||||
# Don't check literal values here: let the API check them
|
||||
output["detail"] = detail
|
||||
output["detail"] = cast("Literal['auto', 'low', 'high']", detail)
|
||||
return output
|
||||
|
||||
async def aformat(self, **kwargs: Any) -> ImageURL:
|
||||
|
||||
@@ -92,4 +92,4 @@ class BaseMessagePromptTemplate(Serializable, ABC):
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415
|
||||
|
||||
prompt = ChatPromptTemplate(messages=[self])
|
||||
return prompt + other
|
||||
return prompt.__add__(other)
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from string import Formatter
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, create_model
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.prompts.base import BasePromptTemplate
|
||||
@@ -189,17 +190,20 @@ def mustache_schema(template: str) -> type[BaseModel]:
|
||||
return _create_model_recursive("PromptInput", defs)
|
||||
|
||||
|
||||
def _create_model_recursive(name: str, defs: Defs) -> type:
|
||||
return create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
**{
|
||||
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
|
||||
for k, v in defs.items()
|
||||
},
|
||||
def _create_model_recursive(name: str, defs: Defs) -> type[BaseModel]:
|
||||
return cast(
|
||||
"type[BaseModel]",
|
||||
create_model( # type: ignore[call-overload]
|
||||
name,
|
||||
**{
|
||||
k: (_create_model_recursive(k, v), None) if v else (type(v), None)
|
||||
for k, v in defs.items()
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = {
|
||||
DEFAULT_FORMATTER_MAPPING: dict[str, Callable[..., str]] = {
|
||||
"f-string": formatter.format,
|
||||
"mustache": mustache_formatter,
|
||||
"jinja2": jinja2_formatter,
|
||||
@@ -330,6 +334,10 @@ class StringPromptTemplate(BasePromptTemplate, ABC):
|
||||
"""
|
||||
return StringPromptValue(text=await self.aformat(**kwargs))
|
||||
|
||||
@override
|
||||
@abstractmethod
|
||||
def format(self, **kwargs: Any) -> str: ...
|
||||
|
||||
def pretty_repr(
|
||||
self,
|
||||
html: bool = False, # noqa: FBT001,FBT002
|
||||
|
||||
Reference in New Issue
Block a user