From a92c032ff69ca6ec84451f8e13af1109d1a4f9ca Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Sat, 27 Dec 2025 04:35:27 +0100 Subject: [PATCH] style(core): fix mypy no-any-return violations (#34204) * FIxed where possible * Used `cast` when not possible to fix --------- Co-authored-by: Mason Daugherty Co-authored-by: Mason Daugherty --- libs/core/langchain_core/_api/internal.py | 3 +- .../langchain_core/language_models/base.py | 3 +- .../language_models/chat_models.py | 7 ++- .../langchain_core/language_models/llms.py | 5 +- libs/core/langchain_core/messages/base.py | 2 +- libs/core/langchain_core/messages/utils.py | 62 +++++++++++++++---- .../langchain_core/output_parsers/base.py | 5 +- .../langchain_core/output_parsers/pydantic.py | 14 ++++- libs/core/langchain_core/prompts/base.py | 22 +++---- libs/core/langchain_core/prompts/dict.py | 11 ++-- .../prompts/few_shot_with_templates.py | 3 +- libs/core/langchain_core/prompts/image.py | 4 +- libs/core/langchain_core/prompts/message.py | 2 +- libs/core/langchain_core/prompts/string.py | 28 ++++++--- libs/core/langchain_core/runnables/base.py | 8 +-- .../langchain_core/runnables/graph_mermaid.py | 4 +- .../langchain_core/runnables/graph_png.py | 4 +- libs/core/langchain_core/runnables/router.py | 4 +- libs/core/langchain_core/runnables/utils.py | 7 ++- libs/core/langchain_core/tools/base.py | 4 +- libs/core/langchain_core/tools/convert.py | 4 +- libs/core/langchain_core/tracers/context.py | 2 +- .../langchain_core/tracers/event_stream.py | 4 +- .../core/langchain_core/tracers/log_stream.py | 4 +- libs/core/langchain_core/utils/env.py | 4 +- .../langchain_core/utils/function_calling.py | 8 ++- libs/core/langchain_core/utils/json.py | 4 +- libs/core/langchain_core/utils/json_schema.py | 6 +- libs/core/langchain_core/utils/pydantic.py | 25 ++++---- .../core/langchain_core/vectorstores/utils.py | 4 +- libs/core/pyproject.toml | 1 - .../tests/unit_tests/messages/test_utils.py | 4 +- .../unit_tests/prompts/test_structured.py | 3 +- libs/core/tests/unit_tests/pydantic_utils.py | 37 +++++------ .../unit_tests/runnables/test_history.py | 5 +- .../unit_tests/runnables/test_runnable.py | 4 +- .../runnables/test_runnable_events_v1.py | 4 +- .../runnables/test_runnable_events_v2.py | 4 +- .../runnables/test_tracing_interops.py | 14 +++-- .../tests/unit_tests/runnables/test_utils.py | 8 +-- libs/core/tests/unit_tests/test_tools.py | 14 ++--- .../agents/output_parsers/json.py | 4 +- 42 files changed, 226 insertions(+), 143 deletions(-) diff --git a/libs/core/langchain_core/_api/internal.py b/libs/core/langchain_core/_api/internal.py index e5ef4300ff6..5bebb59347e 100644 --- a/libs/core/langchain_core/_api/internal.py +++ b/libs/core/langchain_core/_api/internal.py @@ -1,4 +1,5 @@ import inspect +from typing import cast def is_caller_internal(depth: int = 2) -> bool: @@ -16,7 +17,7 @@ def is_caller_internal(depth: int = 2) -> bool: return False # Directly access the module name from the frame's global variables module_globals = frame.f_globals - caller_module_name = module_globals.get("__name__", "") + caller_module_name = cast("str", module_globals.get("__name__", "")) return caller_module_name.startswith("langchain") finally: del frame diff --git a/libs/core/langchain_core/language_models/base.py b/libs/core/langchain_core/language_models/base.py index 67c23fee529..2a8aeece46a 100644 --- a/libs/core/langchain_core/language_models/base.py +++ b/libs/core/langchain_core/language_models/base.py @@ -12,6 +12,7 @@ from typing import ( Literal, TypeAlias, TypeVar, + cast, ) from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -92,7 +93,7 @@ def _get_token_ids_default_method(text: str) -> list[int]: tokenizer = get_tokenizer() # tokenize the text using the GPT-2 tokenizer - return tokenizer.encode(text) + return cast("list[int]", tokenizer.encode(text)) LanguageModelInput = PromptValue | str | Sequence[MessageLikeRepresentation] diff --git a/libs/core/langchain_core/language_models/chat_models.py b/libs/core/langchain_core/language_models/chat_models.py index c7bf78dc8dc..dbb7577b687 100644 --- a/libs/core/langchain_core/language_models/chat_models.py +++ b/libs/core/langchain_core/language_models/chat_models.py @@ -359,7 +359,10 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): @cached_property def _serialized(self) -> dict[str, Any]: - return dumpd(self) + # self is always a Serializable object in this case, thus the result is + # guaranteed to be a dict since dumps uses the default callback, which uses + # obj.to_json which always returns TypedDict subclasses + return cast("dict[str, Any]", dumpd(self)) # --- Runnable methods --- @@ -462,7 +465,7 @@ class BaseChatModel(BaseLanguageModel[AIMessage], ABC): # Check if a runtime streaming flag has been passed in. if "stream" in kwargs: - return kwargs["stream"] + return bool(kwargs["stream"]) if "streaming" in self.model_fields_set: streaming_value = getattr(self, "streaming", None) diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index a3853e41360..04e5600bd27 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -301,7 +301,10 @@ class BaseLLM(BaseLanguageModel[str], ABC): @functools.cached_property def _serialized(self) -> dict[str, Any]: - return dumpd(self) + # self is always a Serializable object in this case, thus the result is + # guaranteed to be a dict since dumps uses the default callback, which uses + # obj.to_json which always returns TypedDict subclasses + return cast("dict[str, Any]", dumpd(self)) # --- Runnable methods --- diff --git a/libs/core/langchain_core/messages/base.py b/libs/core/langchain_core/messages/base.py index 784a95703c3..841df41c11b 100644 --- a/libs/core/langchain_core/messages/base.py +++ b/libs/core/langchain_core/messages/base.py @@ -302,7 +302,7 @@ class BaseMessage(Serializable): from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415 prompt = ChatPromptTemplate(messages=[self]) - return prompt + other + return prompt.__add__(other) def pretty_repr( self, diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 8146c2d35b5..a0aef8af63c 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -65,14 +65,19 @@ logger = logging.getLogger(__name__) def _get_type(v: Any) -> str: """Get the type associated with the object for serialization purposes.""" if isinstance(v, dict) and "type" in v: - return v["type"] - if hasattr(v, "type"): - return v.type - msg = ( - f"Expected either a dictionary with a 'type' key or an object " - f"with a 'type' attribute. Instead got type {type(v)}." - ) - raise TypeError(msg) + result = v["type"] + elif hasattr(v, "type"): + result = v.type + else: + msg = ( + f"Expected either a dictionary with a 'type' key or an object " + f"with a 'type' attribute. Instead got type {type(v)}." + ) + raise TypeError(msg) + if not isinstance(result, str): + msg = f"Expected 'type' to be a str, got {type(result).__name__}" + raise TypeError(msg) + return result AnyMessage = Annotated[ @@ -215,8 +220,11 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage: ignore_keys = ["type"] if isinstance(chunk, AIMessageChunk): ignore_keys.extend(["tool_call_chunks", "chunk_position"]) - return chunk.__class__.__mro__[1]( - **{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys} + return cast( + "BaseMessage", + chunk.__class__.__mro__[1]( + **{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys} + ), ) @@ -1112,6 +1120,32 @@ def trim_messages( raise ValueError(msg) +_SingleMessage = BaseMessage | str | dict[str, Any] +_T = TypeVar("_T", bound=_SingleMessage) +# A sequence of _SingleMessage that is NOT a bare str +_MultipleMessages = Sequence[_T] + + +@overload +def convert_to_openai_messages( + messages: _SingleMessage, + *, + text_format: Literal["string", "block"] = "string", + include_id: bool = False, + pass_through_unknown_blocks: bool = True, +) -> dict: ... + + +@overload +def convert_to_openai_messages( + messages: _MultipleMessages, + *, + text_format: Literal["string", "block"] = "string", + include_id: bool = False, + pass_through_unknown_blocks: bool = True, +) -> list[dict]: ... + + def convert_to_openai_messages( messages: MessageLikeRepresentation | Sequence[MessageLikeRepresentation], *, @@ -1207,7 +1241,7 @@ def convert_to_openai_messages( err = f"Unrecognized {text_format=}, expected one of 'string' or 'block'." raise ValueError(err) - oai_messages: list = [] + oai_messages: list[dict] = [] if is_single := isinstance(messages, (BaseMessage, dict, str)): messages = [messages] @@ -1774,7 +1808,11 @@ def _get_message_openai_role(message: BaseMessage) -> str: if isinstance(message, ToolMessage): return "tool" if isinstance(message, SystemMessage): - return message.additional_kwargs.get("__openai_role__", "system") + role = message.additional_kwargs.get("__openai_role__", "system") + if not isinstance(role, str): + msg = f"Expected '__openai_role__' to be a str, got {type(role).__name__}" + raise TypeError(msg) + return role if isinstance(message, FunctionMessage): return "function" if isinstance(message, ChatMessage): diff --git a/libs/core/langchain_core/output_parsers/base.py b/libs/core/langchain_core/output_parsers/base.py index 53f5240a96c..18d1d547533 100644 --- a/libs/core/langchain_core/output_parsers/base.py +++ b/libs/core/langchain_core/output_parsers/base.py @@ -9,6 +9,7 @@ from typing import ( Any, Generic, TypeVar, + cast, ) from typing_extensions import override @@ -77,7 +78,7 @@ class BaseGenerationOutputParser( """Return the output type for the parser.""" # even though mypy complains this isn't valid, # it is good enough for pydantic to build the schema from - return T # type: ignore[misc] + return cast("type[T]", T) # type: ignore[misc] @override def invoke( @@ -181,7 +182,7 @@ class BaseOutputParser( if hasattr(base, "__pydantic_generic_metadata__"): metadata = base.__pydantic_generic_metadata__ if "args" in metadata and len(metadata["args"]) > 0: - return metadata["args"][0] + return cast("type[T]", metadata["args"][0]) msg = ( f"Runnable {self.__class__.__name__} doesn't have an inferable OutputType. " diff --git a/libs/core/langchain_core/output_parsers/pydantic.py b/libs/core/langchain_core/output_parsers/pydantic.py index cc3d186dbe7..9a9f4f13111 100644 --- a/libs/core/langchain_core/output_parsers/pydantic.py +++ b/libs/core/langchain_core/output_parsers/pydantic.py @@ -1,7 +1,7 @@ """Output parsers using Pydantic.""" import json -from typing import Annotated, Generic +from typing import Annotated, Generic, Literal, overload import pydantic from pydantic import SkipValidation @@ -42,6 +42,16 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): msg = f"Failed to parse {name} from completion {json_string}. Got: {e}" return OutputParserException(msg, llm_output=json_string) + @overload + def parse_result( + self, result: list[Generation], *, partial: Literal[False] = False + ) -> TBaseModel: ... + + @overload + def parse_result( + self, result: list[Generation], *, partial: bool = False + ) -> TBaseModel | None: ... + def parse_result( self, result: list[Generation], *, partial: bool = False ) -> TBaseModel | None: @@ -77,7 +87,7 @@ class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]): Returns: The parsed Pydantic object. """ - return super().parse(text) + return self.parse_result([Generation(text=text)]) def get_format_instructions(self) -> str: """Return the format instructions for the JSON output. diff --git a/libs/core/langchain_core/prompts/base.py b/libs/core/langchain_core/prompts/base.py index d6dc5c56d61..dbe21834a24 100644 --- a/libs/core/langchain_core/prompts/base.py +++ b/libs/core/langchain_core/prompts/base.py @@ -9,12 +9,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping # noqa: TC003 from functools import cached_property from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Generic, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast import yaml from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -122,7 +117,10 @@ class BasePromptTemplate( @cached_property def _serialized(self) -> dict[str, Any]: - return dumpd(self) + # self is always a Serializable object in this case, thus the result is + # guaranteed to be a dict since dumpd uses the default callback, which uses + # obj.to_json which always returns TypedDict subclasses + return cast("dict[str, Any]", dumpd(self)) @property @override @@ -156,7 +154,7 @@ class BasePromptTemplate( if not isinstance(inner_input, dict): if len(self.input_variables) == 1: var_name = self.input_variables[0] - inner_input = {var_name: inner_input} + inner_input_ = {var_name: inner_input} else: msg = ( @@ -168,12 +166,14 @@ class BasePromptTemplate( message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT ) ) - missing = set(self.input_variables).difference(inner_input) + else: + inner_input_ = inner_input + missing = set(self.input_variables).difference(inner_input_) if missing: msg = ( f"Input to {self.__class__.__name__} is missing variables {missing}. " f" Expected: {self.input_variables}" - f" Received: {list(inner_input.keys())}" + f" Received: {list(inner_input_.keys())}" ) example_key = missing.pop() msg += ( @@ -184,7 +184,7 @@ class BasePromptTemplate( raise KeyError( create_message(message=msg, error_code=ErrorCode.INVALID_PROMPT_INPUT) ) - return inner_input + return inner_input_ def _format_prompt_with_error_handling(self, inner_input: dict) -> PromptValue: inner_input_ = self._validate_input(inner_input) diff --git a/libs/core/langchain_core/prompts/dict.py b/libs/core/langchain_core/prompts/dict.py index 1d6f76384ed..dca8d9e80a3 100644 --- a/libs/core/langchain_core/prompts/dict.py +++ b/libs/core/langchain_core/prompts/dict.py @@ -2,7 +2,7 @@ import warnings from functools import cached_property -from typing import Any, Literal +from typing import Any, Literal, cast from typing_extensions import override @@ -65,7 +65,10 @@ class DictPromptTemplate(RunnableSerializable[dict, dict]): @cached_property def _serialized(self) -> dict[str, Any]: - return dumpd(self) + # self is always a Serializable object in this case, thus the result is + # guaranteed to be a dict since dumpd uses the default callback, which uses + # obj.to_json which always returns TypedDict subclasses + return cast("dict[str, Any]", dumpd(self)) @classmethod def is_lc_serializable(cls) -> bool: @@ -116,7 +119,7 @@ def _insert_input_variables( inputs: dict[str, Any], template_format: Literal["f-string", "mustache"], ) -> dict[str, Any]: - formatted = {} + formatted: dict[str, Any] = {} formatter = DEFAULT_FORMATTER_MAPPING[template_format] for k, v in template.items(): if isinstance(v, str): @@ -132,7 +135,7 @@ def _insert_input_variables( warnings.warn(msg, stacklevel=2) formatted[k] = _insert_input_variables(v, inputs, template_format) elif isinstance(v, (list, tuple)): - formatted_v = [] + formatted_v: list[str | dict[str, Any]] = [] for x in v: if isinstance(x, str): formatted_v.append(formatter(x, **inputs)) diff --git a/libs/core/langchain_core/prompts/few_shot_with_templates.py b/libs/core/langchain_core/prompts/few_shot_with_templates.py index 0693de2205f..1df73bc3229 100644 --- a/libs/core/langchain_core/prompts/few_shot_with_templates.py +++ b/libs/core/langchain_core/prompts/few_shot_with_templates.py @@ -6,6 +6,7 @@ from typing import Any from pydantic import ConfigDict, model_validator from typing_extensions import Self +from langchain_core.example_selectors import BaseExampleSelector from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.string import ( DEFAULT_FORMATTER_MAPPING, @@ -21,7 +22,7 @@ class FewShotPromptWithTemplates(StringPromptTemplate): """Examples to format into the prompt. Either this or example_selector should be provided.""" - example_selector: Any = None + example_selector: BaseExampleSelector | None = None """ExampleSelector to choose the examples to format into the prompt. Either this or examples should be provided.""" diff --git a/libs/core/langchain_core/prompts/image.py b/libs/core/langchain_core/prompts/image.py index c650a032e73..bdd3db31236 100644 --- a/libs/core/langchain_core/prompts/image.py +++ b/libs/core/langchain_core/prompts/image.py @@ -1,6 +1,6 @@ """Image prompt template for a multimodal model.""" -from typing import Any +from typing import Any, Literal, cast from pydantic import Field @@ -125,7 +125,7 @@ class ImagePromptTemplate(BasePromptTemplate[ImageURL]): output: ImageURL = {"url": url} if detail: # Don't check literal values here: let the API check them - output["detail"] = detail + output["detail"] = cast("Literal['auto', 'low', 'high']", detail) return output async def aformat(self, **kwargs: Any) -> ImageURL: diff --git a/libs/core/langchain_core/prompts/message.py b/libs/core/langchain_core/prompts/message.py index fb7dd88b5c2..ae0ccc03530 100644 --- a/libs/core/langchain_core/prompts/message.py +++ b/libs/core/langchain_core/prompts/message.py @@ -92,4 +92,4 @@ class BaseMessagePromptTemplate(Serializable, ABC): from langchain_core.prompts.chat import ChatPromptTemplate # noqa: PLC0415 prompt = ChatPromptTemplate(messages=[self]) - return prompt + other + return prompt.__add__(other) diff --git a/libs/core/langchain_core/prompts/string.py b/libs/core/langchain_core/prompts/string.py index f34b2e294b0..bb255f645a9 100644 --- a/libs/core/langchain_core/prompts/string.py +++ b/libs/core/langchain_core/prompts/string.py @@ -3,11 +3,12 @@ from __future__ import annotations import warnings -from abc import ABC +from abc import ABC, abstractmethod from string import Formatter -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast from pydantic import BaseModel, create_model +from typing_extensions import override from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.prompts.base import BasePromptTemplate @@ -189,17 +190,20 @@ def mustache_schema(template: str) -> type[BaseModel]: return _create_model_recursive("PromptInput", defs) -def _create_model_recursive(name: str, defs: Defs) -> type: - return create_model( # type: ignore[call-overload] - name, - **{ - k: (_create_model_recursive(k, v), None) if v else (type(v), None) - for k, v in defs.items() - }, +def _create_model_recursive(name: str, defs: Defs) -> type[BaseModel]: + return cast( + "type[BaseModel]", + create_model( # type: ignore[call-overload] + name, + **{ + k: (_create_model_recursive(k, v), None) if v else (type(v), None) + for k, v in defs.items() + }, + ), ) -DEFAULT_FORMATTER_MAPPING: dict[str, Callable] = { +DEFAULT_FORMATTER_MAPPING: dict[str, Callable[..., str]] = { "f-string": formatter.format, "mustache": mustache_formatter, "jinja2": jinja2_formatter, @@ -330,6 +334,10 @@ class StringPromptTemplate(BasePromptTemplate, ABC): """ return StringPromptValue(text=await self.aformat(**kwargs)) + @override + @abstractmethod + def format(self, **kwargs: Any) -> str: ... + def pretty_repr( self, html: bool = False, # noqa: FBT001,FBT002 diff --git a/libs/core/langchain_core/runnables/base.py b/libs/core/langchain_core/runnables/base.py index 577d3593712..b3435cfe6c4 100644 --- a/libs/core/langchain_core/runnables/base.py +++ b/libs/core/langchain_core/runnables/base.py @@ -315,7 +315,7 @@ class Runnable(ABC, Generic[Input, Output]): "args" in metadata and len(metadata["args"]) == _RUNNABLE_GENERIC_NUM_ARGS ): - return metadata["args"][0] + return cast("type[Input]", metadata["args"][0]) # If we didn't find a Pydantic model in the parent classes, # then loop through __orig_bases__. This corresponds to @@ -323,7 +323,7 @@ class Runnable(ABC, Generic[Input, Output]): for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) if type_args and len(type_args) == _RUNNABLE_GENERIC_NUM_ARGS: - return type_args[0] + return cast("type[Input]", type_args[0]) msg = ( f"Runnable {self.get_name()} doesn't have an inferable InputType. " @@ -349,12 +349,12 @@ class Runnable(ABC, Generic[Input, Output]): "args" in metadata and len(metadata["args"]) == _RUNNABLE_GENERIC_NUM_ARGS ): - return metadata["args"][1] + return cast("type[Output]", metadata["args"][1]) for cls in self.__class__.__orig_bases__: # type: ignore[attr-defined] type_args = get_args(cls) if type_args and len(type_args) == _RUNNABLE_GENERIC_NUM_ARGS: - return type_args[1] + return cast("type[Output]", type_args[1]) msg = ( f"Runnable {self.get_name()} doesn't have an inferable OutputType. " diff --git a/libs/core/langchain_core/runnables/graph_mermaid.py b/libs/core/langchain_core/runnables/graph_mermaid.py index ddef49b866c..0e500598335 100644 --- a/libs/core/langchain_core/runnables/graph_mermaid.py +++ b/libs/core/langchain_core/runnables/graph_mermaid.py @@ -11,7 +11,7 @@ import time import urllib.parse from dataclasses import asdict from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast import yaml @@ -393,7 +393,7 @@ async def _render_mermaid_using_pyppeteer( } ) - img_bytes = await page.screenshot({"fullPage": False}) + img_bytes = cast("bytes", await page.screenshot({"fullPage": False})) await browser.close() if output_file_path is not None: diff --git a/libs/core/langchain_core/runnables/graph_png.py b/libs/core/langchain_core/runnables/graph_png.py index 737f9445d92..97b6b1f21f2 100644 --- a/libs/core/langchain_core/runnables/graph_png.py +++ b/libs/core/langchain_core/runnables/graph_png.py @@ -1,7 +1,7 @@ """Helper class to draw a state graph into a PNG file.""" from itertools import groupby -from typing import Any +from typing import Any, cast from langchain_core.runnables.graph import Graph, LabelsDict @@ -149,7 +149,7 @@ class PngDrawer: # Save the graph as PNG try: - return viz.draw(output_path, format="png", prog="dot") + return cast("bytes | None", viz.draw(output_path, format="png", prog="dot")) finally: viz.close() diff --git a/libs/core/langchain_core/runnables/router.py b/libs/core/langchain_core/runnables/router.py index e993f233320..4b62beba0c3 100644 --- a/libs/core/langchain_core/runnables/router.py +++ b/libs/core/langchain_core/runnables/router.py @@ -151,7 +151,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): raise ValueError(msg) def invoke( - runnable: Runnable, input_: Input, config: RunnableConfig + runnable: Runnable[Input, Output], input_: Input, config: RunnableConfig ) -> Output | Exception: if return_exceptions: try: @@ -188,7 +188,7 @@ class RouterRunnable(RunnableSerializable[RouterInput, Output]): raise ValueError(msg) async def ainvoke( - runnable: Runnable, input_: Input, config: RunnableConfig + runnable: Runnable[Input, Output], input_: Input, config: RunnableConfig ) -> Output | Exception: if return_exceptions: try: diff --git a/libs/core/langchain_core/runnables/utils.py b/libs/core/langchain_core/runnables/utils.py index 05c42229782..103226c4993 100644 --- a/libs/core/langchain_core/runnables/utils.py +++ b/libs/core/langchain_core/runnables/utils.py @@ -132,9 +132,12 @@ def asyncio_accepts_context() -> bool: return sys.version_info >= (3, 11) +_T = TypeVar("_T") + + def coro_with_context( - coro: Awaitable[Any], context: Context, *, create_task: bool = False -) -> Awaitable[Any]: + coro: Awaitable[_T], context: Context, *, create_task: bool = False +) -> Awaitable[_T]: """Await a coroutine with a context. Args: diff --git a/libs/core/langchain_core/tools/base.py b/libs/core/langchain_core/tools/base.py index e5ad58b8862..5d86a4cf3a3 100644 --- a/libs/core/langchain_core/tools/base.py +++ b/libs/core/langchain_core/tools/base.py @@ -568,7 +568,7 @@ class ChildTool(BaseTool): else: input_schema = self.get_input_schema() json_schema = input_schema.model_json_schema() - return json_schema["properties"] + return cast("dict", json_schema["properties"]) @property def tool_call_schema(self) -> ArgsSchema: @@ -1545,7 +1545,7 @@ def _replace_type_vars( _replace_type_vars(arg, generic_map, default_to_bound=default_to_bound) for arg in args ) - return _py_38_safe_origin(origin)[new_args] # type: ignore[index] + return cast("type", _py_38_safe_origin(origin)[new_args]) # type: ignore[index] return type_ diff --git a/libs/core/langchain_core/tools/convert.py b/libs/core/langchain_core/tools/convert.py index be66424c723..3f6357698f2 100644 --- a/libs/core/langchain_core/tools/convert.py +++ b/libs/core/langchain_core/tools/convert.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable -from typing import Any, Literal, get_type_hints, overload +from typing import Any, Literal, cast, get_type_hints, overload from pydantic import BaseModel, Field, create_model @@ -407,7 +407,7 @@ def _get_schema_from_runnable_and_arg_types( ) raise TypeError(msg) from e fields = {key: (key_type, Field(...)) for key, key_type in arg_types.items()} - return create_model(name, **fields) # type: ignore[call-overload] + return cast("type[BaseModel]", create_model(name, **fields)) # type: ignore[call-overload] def convert_runnable_to_tool( diff --git a/libs/core/langchain_core/tracers/context.py b/libs/core/langchain_core/tracers/context.py index 982054b1322..fac7b3a0871 100644 --- a/libs/core/langchain_core/tracers/context.py +++ b/libs/core/langchain_core/tracers/context.py @@ -137,7 +137,7 @@ def _get_tracer_project() -> str: tracing_context = ls_rh.get_tracing_context() run_tree = tracing_context["parent"] if run_tree is None and tracing_context["project_name"] is not None: - return tracing_context["project_name"] + return cast("str", tracing_context["project_name"]) return getattr( run_tree, "session_name", diff --git a/libs/core/langchain_core/tracers/event_stream.py b/libs/core/langchain_core/tracers/event_stream.py index 0aa98b373f6..f240026853e 100644 --- a/libs/core/langchain_core/tracers/event_stream.py +++ b/libs/core/langchain_core/tracers/event_stream.py @@ -81,9 +81,9 @@ def _assign_name(name: str | None, serialized: dict[str, Any] | None) -> str: return name if serialized is not None: if "name" in serialized: - return serialized["name"] + return cast("str", serialized["name"]) if "id" in serialized: - return serialized["id"][-1] + return cast("str", serialized["id"][-1]) return "Unnamed" diff --git a/libs/core/langchain_core/tracers/log_stream.py b/libs/core/langchain_core/tracers/log_stream.py index 1240d9eb851..2bc1a529e79 100644 --- a/libs/core/langchain_core/tracers/log_stream.py +++ b/libs/core/langchain_core/tracers/log_stream.py @@ -541,7 +541,7 @@ class LogStreamCallbackHandler(BaseTracer, _StreamingCallbackHandler): def _get_standardized_inputs( run: Run, schema_format: Literal["original", "streaming_events"] -) -> dict[str, Any] | None: +) -> Any: """Extract standardized inputs from a run. Standardizes the inputs based on the type of the runnable used. @@ -570,7 +570,7 @@ def _get_standardized_inputs( # new style chains # These nest an additional 'input' key inside the 'inputs' to make sure - # the input is always a dict. We need to unpack and user the inner value. + # the input is always a dict. We need to unpack and use the inner value. inputs = inputs["input"] # We should try to fix this in Runnables and callbacks/tracers # Runnables should be using a None type here not a placeholder diff --git a/libs/core/langchain_core/utils/env.py b/libs/core/langchain_core/utils/env.py index f19928e8761..2f103ee5403 100644 --- a/libs/core/langchain_core/utils/env.py +++ b/libs/core/langchain_core/utils/env.py @@ -46,10 +46,10 @@ def get_from_dict_or_env( if isinstance(key, (list, tuple)): for k in key: if value := data.get(k): - return value + return str(value) if isinstance(key, str) and key in data and data[key]: - return data[key] + return str(data[key]) key_for_err = key[0] if isinstance(key, (list, tuple)) else key diff --git a/libs/core/langchain_core/utils/function_calling.py b/libs/core/langchain_core/utils/function_calling.py index 4e1fc3b46a0..e341191d972 100644 --- a/libs/core/langchain_core/utils/function_calling.py +++ b/libs/core/langchain_core/utils/function_calling.py @@ -237,7 +237,7 @@ _MAX_TYPED_DICT_RECURSION = 25 def _convert_any_typed_dicts_to_pydantic( type_: type, *, - visited: dict, + visited: dict[type, type], depth: int = 0, ) -> type: if type_ in visited: @@ -288,7 +288,9 @@ def _convert_any_typed_dicts_to_pydantic( if arg_desc := arg_descriptions.get(arg): field_kwargs["description"] = arg_desc fields[arg] = (new_arg_type, Field_v1(**field_kwargs)) - model = create_model_v1(typed_dict.__name__, **fields) + model = cast( + "type[BaseModelV1]", create_model_v1(typed_dict.__name__, **fields) + ) model.__doc__ = description visited[typed_dict] = model return model @@ -298,7 +300,7 @@ def _convert_any_typed_dicts_to_pydantic( _convert_any_typed_dicts_to_pydantic(arg, depth=depth + 1, visited=visited) for arg in type_args ) - return subscriptable_origin[type_args] # type: ignore[index] + return cast("type", subscriptable_origin[type_args]) # type: ignore[index] return type_ diff --git a/libs/core/langchain_core/utils/json.py b/libs/core/langchain_core/utils/json.py index 0d5907b4d6e..308c2740457 100644 --- a/libs/core/langchain_core/utils/json.py +++ b/libs/core/langchain_core/utils/json.py @@ -134,7 +134,7 @@ _json_markdown_re = re.compile(r"```(json)?(.*)", re.DOTALL) def parse_json_markdown( json_string: str, *, parser: Callable[[str], Any] = parse_partial_json -) -> dict: +) -> Any: """Parse a JSON string from a Markdown string. Args: @@ -161,7 +161,7 @@ _json_strip_chars = " \n\r\t`" def _parse_json( json_str: str, *, parser: Callable[[str], Any] = parse_partial_json -) -> dict: +) -> Any: # Strip whitespace,newlines,backtick from the start and end json_str = json_str.strip(_json_strip_chars) diff --git a/libs/core/langchain_core/utils/json_schema.py b/libs/core/langchain_core/utils/json_schema.py index 1874e45aec4..72f56fc2efa 100644 --- a/libs/core/langchain_core/utils/json_schema.py +++ b/libs/core/langchain_core/utils/json_schema.py @@ -3,7 +3,7 @@ from __future__ import annotations from copy import deepcopy -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast if TYPE_CHECKING: from collections.abc import Sequence @@ -243,4 +243,6 @@ def dereference_refs( full = full_schema or schema_obj keys_to_skip = list(skip_keys) if skip_keys is not None else ["$defs"] shallow = skip_keys is None - return _dereference_refs_helper(schema_obj, full, None, keys_to_skip, shallow) + return cast( + "dict", _dereference_refs_helper(schema_obj, full, None, keys_to_skip, shallow) + ) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index fb5872c91a2..fcf70615a9d 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -143,7 +143,7 @@ def pre_init(func: Callable) -> Any: # So we keep root_validator for backward compatibility. @root_validator(pre=True) # type: ignore[deprecated] @wraps(func) - def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> dict[str, Any]: + def wrapper(cls: type[BaseModel], values: dict[str, Any]) -> Any: """Decorator to run a function before model initialization. Args: @@ -206,7 +206,7 @@ def _create_subset_model_v1( *, descriptions: dict | None = None, fn_description: str | None = None, -) -> type[BaseModel]: +) -> type[BaseModelV1]: """Create a Pydantic model with only a subset of model's fields.""" fields = {} @@ -223,7 +223,7 @@ def _create_subset_model_v1( field.field_info.description = descriptions[field_name] fields[field_name] = (t, field.field_info) - rtn = create_model_v1(name, **fields) # type: ignore[call-overload] + rtn = cast("type[BaseModelV1]", create_model_v1(name, **fields)) # type: ignore[call-overload] rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "") return rtn @@ -247,8 +247,11 @@ def _create_subset_model_v2( field_info.metadata = field.metadata fields[field_name] = (field.annotation, field_info) - rtn = _create_model_base( # type: ignore[call-overload] - name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) + rtn = cast( + "type[BaseModel]", + _create_model_base( # type: ignore[call-overload] + name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) + ), ) # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. @@ -353,14 +356,12 @@ def _create_root_model( """Create a base class.""" def schema( - cls: type[BaseModel], + cls: type[BaseModelV1], by_alias: bool = True, # noqa: FBT001,FBT002 ref_template: str = DEFAULT_REF_TEMPLATE, ) -> dict[str, Any]: - # Complains about schema not being defined in superclass - schema_ = super(cls, cls).schema( # type: ignore[misc] - by_alias=by_alias, ref_template=ref_template - ) + super_cls = cast("type[BaseModelV1]", super(cls, cls)) + schema_ = super_cls.schema(by_alias=by_alias, ref_template=ref_template) schema_["title"] = name return schema_ @@ -371,8 +372,8 @@ def _create_root_model( schema_generator: type[GenerateJsonSchema] = GenerateJsonSchema, mode: JsonSchemaMode = "validation", ) -> dict[str, Any]: - # Complains about model_json_schema not being defined in superclass - schema_ = super(cls, cls).model_json_schema( # type: ignore[misc] + super_cls = cast("type[BaseModel]", super(cls, cls)) + schema_ = super_cls.model_json_schema( by_alias=by_alias, ref_template=ref_template, schema_generator=schema_generator, diff --git a/libs/core/langchain_core/vectorstores/utils.py b/libs/core/langchain_core/vectorstores/utils.py index 855af9211e7..23f0c40e6b8 100644 --- a/libs/core/langchain_core/vectorstores/utils.py +++ b/libs/core/langchain_core/vectorstores/utils.py @@ -8,7 +8,7 @@ from __future__ import annotations import logging import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast try: import numpy as np @@ -94,7 +94,7 @@ def _cosine_similarity(x: Matrix, y: Matrix) -> np.ndarray: msg = "NaN values found, please remove the NaN values and try again" raise ValueError(msg) from None similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0 - return similarity + return cast("np.ndarray", similarity) x = np.array(x, dtype=np.float32) y = np.array(y, dtype=np.float32) diff --git a/libs/core/pyproject.toml b/libs/core/pyproject.toml index dc06b95019e..730e09cda9e 100644 --- a/libs/core/pyproject.toml +++ b/libs/core/pyproject.toml @@ -76,7 +76,6 @@ enable_error_code = "deprecated" # TODO: activate for 'strict' checking disallow_any_generics = false -warn_return_any = false [tool.ruff.format] diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index 3aa27351c04..7fd61f9a6ee 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -1032,9 +1032,9 @@ def test_convert_to_openai_messages_openai_block() -> None: def test_convert_to_openai_messages_invalid_format() -> None: with pytest.raises(ValueError, match="Unrecognized text_format="): - convert_to_openai_messages( + convert_to_openai_messages( # type: ignore[call-overload] [HumanMessage(content="Hello")], - text_format="invalid", # type: ignore[arg-type] + text_format="invalid", ) diff --git a/libs/core/tests/unit_tests/prompts/test_structured.py b/libs/core/tests/unit_tests/prompts/test_structured.py index 5e6419d8edc..48fb133806c 100644 --- a/libs/core/tests/unit_tests/prompts/test_structured.py +++ b/libs/core/tests/unit_tests/prompts/test_structured.py @@ -13,13 +13,12 @@ from langchain_core.messages import HumanMessage from langchain_core.prompts.structured import StructuredPrompt from langchain_core.runnables.base import Runnable, RunnableLambda from langchain_core.utils.mustache import ChevronError -from langchain_core.utils.pydantic import is_basemodel_subclass def _fake_runnable( _: Any, *, schema: dict | type[BaseModel], value: Any = 42, **_kwargs: Any ) -> BaseModel | dict: - if isclass(schema) and is_basemodel_subclass(schema): + if isclass(schema) and issubclass(schema, BaseModel): return schema(name="yo", value=value) params = cast("dict", schema)["parameters"] return {k: 1 if k != "value" else value for k, v in params.items()} diff --git a/libs/core/tests/unit_tests/pydantic_utils.py b/libs/core/tests/unit_tests/pydantic_utils.py index b8235415343..2f01494e1d7 100644 --- a/libs/core/tests/unit_tests/pydantic_utils.py +++ b/libs/core/tests/unit_tests/pydantic_utils.py @@ -1,8 +1,8 @@ +from inspect import isclass from typing import Any from pydantic import BaseModel - -from langchain_core.utils.pydantic import is_basemodel_subclass +from pydantic.v1 import BaseModel as BaseModelV1 # Function to replace allOf with $ref @@ -75,27 +75,28 @@ def _remove_enum(obj: Any) -> None: def _schema(obj: Any) -> dict: """Return the schema of the object.""" - if not is_basemodel_subclass(obj): - msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}" - raise TypeError(msg) # Remap to old style schema - if not hasattr(obj, "model_json_schema"): # V1 model - return obj.schema() + if isclass(obj): + if issubclass(obj, BaseModelV1): + return obj.schema() + if issubclass(obj, BaseModel): + schema_ = obj.model_json_schema(ref_template="#/definitions/{model}") + if "$defs" in schema_: + schema_["definitions"] = schema_["$defs"] + del schema_["$defs"] - schema_ = obj.model_json_schema(ref_template="#/definitions/{model}") - if "$defs" in schema_: - schema_["definitions"] = schema_["$defs"] - del schema_["$defs"] + if "default" in schema_ and schema_["default"] is None: + del schema_["default"] - if "default" in schema_ and schema_["default"] is None: - del schema_["default"] + replace_all_of_with_ref(schema_) + remove_all_none_default(schema_) + _remove_additionalproperties(schema_) + _remove_enum(schema_) - replace_all_of_with_ref(schema_) - remove_all_none_default(schema_) - _remove_additionalproperties(schema_) - _remove_enum(schema_) + return schema_ - return schema_ + msg = f"Object must be a Pydantic BaseModel subclass. Got {type(obj)}" + raise TypeError(msg) def _remove_additionalproperties(schema: dict) -> dict[str, Any]: diff --git a/libs/core/tests/unit_tests/runnables/test_history.py b/libs/core/tests/unit_tests/runnables/test_history.py index 5a1c3f827aa..b7e8bba566a 100644 --- a/libs/core/tests/unit_tests/runnables/test_history.py +++ b/libs/core/tests/unit_tests/runnables/test_history.py @@ -37,7 +37,7 @@ def test_interfaces() -> None: def _get_get_session_history( *, - store: dict[str, Any] | None = None, + store: dict[str, InMemoryChatMessageHistory] | None = None, ) -> Callable[..., InMemoryChatMessageHistory]: chat_history_store = store if store is not None else {} @@ -96,7 +96,8 @@ async def test_input_messages_async() -> None: output = [ c async for c in with_history.astream([HumanMessage(content="hi again")], config) # type: ignore[arg-type] - ] == ["you said: hello\ngood bye\nhi again"] + ] + assert output == ["you said: hello\ngood bye\nhi again"] assert store == { "1_async": InMemoryChatMessageHistory( messages=[ diff --git a/libs/core/tests/unit_tests/runnables/test_runnable.py b/libs/core/tests/unit_tests/runnables/test_runnable.py index 0d8f5863864..c1ad76ef761 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable.py @@ -3506,7 +3506,7 @@ def test_bind_bind() -> None: def test_bind_with_lambda() -> None: def my_function(_: Any, **kwargs: Any) -> int: - return 3 + kwargs.get("n", 0) + return 3 + int(kwargs.get("n", 0)) runnable = RunnableLambda(my_function).bind(n=1) assert runnable.invoke({}) == 4 @@ -3516,7 +3516,7 @@ def test_bind_with_lambda() -> None: async def test_bind_with_lambda_async() -> None: def my_function(_: Any, **kwargs: Any) -> int: - return 3 + kwargs.get("n", 0) + return 3 + int(kwargs.get("n", 0)) runnable = RunnableLambda(my_function).bind(n=1) assert await runnable.ainvoke({}) == 4 diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py index 2023c826306..4f3a37a331d 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v1.py @@ -2121,7 +2121,7 @@ async def test_async_in_async_stream_lambdas() -> None: async def add_one(x: int) -> int: return x + 1 - add_one_ = RunnableLambda(add_one) # type: ignore[arg-type,var-annotated] + add_one_ = RunnableLambda[int, int](add_one) async def add_one_proxy(x: int, config: RunnableConfig) -> int: # Use sync streaming @@ -2129,7 +2129,7 @@ async def test_async_in_async_stream_lambdas() -> None: results = [result async for result in streaming] return results[0] - add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore[arg-type,var-annotated] + add_one_proxy_ = RunnableLambda[int, int](add_one_proxy) events = await _collect_events(add_one_proxy_.astream_events(1, version="v1")) _assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS) diff --git a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py index f36648d0c0e..d2ef38ac2ec 100644 --- a/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py +++ b/libs/core/tests/unit_tests/runnables/test_runnable_events_v2.py @@ -2088,7 +2088,7 @@ async def test_async_in_async_stream_lambdas() -> None: async def add_one(x: int) -> int: return x + 1 - add_one_ = RunnableLambda(add_one) # type: ignore[arg-type,var-annotated] + add_one_ = RunnableLambda[int, int](add_one) async def add_one_proxy(x: int, config: RunnableConfig) -> int: # Use sync streaming @@ -2096,7 +2096,7 @@ async def test_async_in_async_stream_lambdas() -> None: results = [result async for result in streaming] return results[0] - add_one_proxy_ = RunnableLambda(add_one_proxy) # type: ignore[arg-type,var-annotated] + add_one_proxy_ = RunnableLambda[int, int](add_one_proxy) events = await _collect_events(add_one_proxy_.astream_events(1, version="v2")) _assert_events_equal_allow_superset_metadata(events, EXPECTED_EVENTS) diff --git a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py index 4ed5863ab57..5e7b8f1dd86 100644 --- a/libs/core/tests/unit_tests/runnables/test_tracing_interops.py +++ b/libs/core/tests/unit_tests/runnables/test_tracing_interops.py @@ -281,7 +281,7 @@ class TestRunnableSequenceParallelTraceNesting: return x def after(x: dict) -> int: - return x["chain_result"] + return int(x["chain_result"]) sequence = before | parallel | after if isasyncgenfunction(other_thing): @@ -392,15 +392,21 @@ class TestRunnableSequenceParallelTraceNesting: self._check_posts() @staticmethod - async def ainvoke(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + async def ainvoke( + parent: RunnableLambda[int, int], cb: list[BaseCallbackHandler] + ) -> int: return await parent.ainvoke(1, {"callbacks": cb}) @staticmethod - async def astream(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + async def astream( + parent: RunnableLambda[int, int], cb: list[BaseCallbackHandler] + ) -> int: return [res async for res in parent.astream(1, {"callbacks": cb})][-1] @staticmethod - async def abatch(parent: RunnableLambda, cb: list[BaseCallbackHandler]) -> int: + async def abatch( + parent: RunnableLambda[int, int], cb: list[BaseCallbackHandler] + ) -> int: return (await parent.abatch([1], {"callbacks": cb}))[0] @pytest.mark.skipif( diff --git a/libs/core/tests/unit_tests/runnables/test_utils.py b/libs/core/tests/unit_tests/runnables/test_utils.py index 7eec2c52a7a..4dea8a95556 100644 --- a/libs/core/tests/unit_tests/runnables/test_utils.py +++ b/libs/core/tests/unit_tests/runnables/test_utils.py @@ -37,17 +37,17 @@ def test_indent_lines_after_first(text: str, prefix: str, expected_output: str) assert indented_text == expected_output -global_agent = RunnableLambda(lambda x: x * 3) +global_agent = RunnableLambda[str, str](lambda x: x * 3) def test_nonlocals() -> None: - agent = RunnableLambda(lambda x: x * 2) + agent = RunnableLambda[str, str](lambda x: x * 2) def my_func(value: str, agent: dict[str, str]) -> str: return agent.get("agent_name", value) def my_func2(value: str) -> str: - return agent.get("agent_name", value) # type: ignore[attr-defined] + return str(agent.get("agent_name", value)) # type: ignore[attr-defined] def my_func3(value: str) -> str: return agent.invoke(value) @@ -56,7 +56,7 @@ def test_nonlocals() -> None: return global_agent.invoke(value) def my_func5() -> tuple[Callable[[str], str], RunnableLambda]: - global_agent = RunnableLambda(lambda x: x * 3) + global_agent = RunnableLambda[str, str](lambda x: x * 3) def my_func6(value: str) -> str: return global_agent.invoke(value) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 99c42e83b08..0c62c2d9fa7 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -1880,13 +1880,13 @@ def test_tool_inherited_injected_arg() -> None: def _get_parametrized_tools() -> list: def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str: """my_tool.""" - return some_tool + return "my_tool" async def my_async_tool( x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg] ) -> str: """my_tool.""" - return some_tool + return "my_tool" return [my_tool, my_async_tool] @@ -2311,7 +2311,7 @@ def test_tool_args_schema_pydantic_v2_with_metadata() -> None: @tool(args_schema=Foo) def foo(x) -> list[int]: # type: ignore[no-untyped-def] # noqa: ANN001 """Foo.""" - return x + return x # type: ignore[no-any-return] assert _get_tool_call_json_schema(foo) == { "description": "Foo.", @@ -2493,7 +2493,7 @@ def test_tool_injected_tool_call_id() -> None: @tool def foo(x: int, tool_call_id: Annotated[str, InjectedToolCallId]) -> ToolMessage: """Foo.""" - return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[call-overload] + return ToolMessage(str(x), tool_call_id=tool_call_id) assert foo.invoke( { @@ -2502,7 +2502,7 @@ def test_tool_injected_tool_call_id() -> None: "name": "foo", "id": "bar", } - ) == ToolMessage(0, tool_call_id="bar") # type: ignore[call-overload] + ) == ToolMessage("0", tool_call_id="bar") with pytest.raises( ValueError, @@ -2514,7 +2514,7 @@ def test_tool_injected_tool_call_id() -> None: @tool def foo2(x: int, tool_call_id: Annotated[str, InjectedToolCallId()]) -> ToolMessage: """Foo.""" - return ToolMessage(x, tool_call_id=tool_call_id) # type: ignore[call-overload] + return ToolMessage(str(x), tool_call_id=tool_call_id) assert foo2.invoke( { @@ -2523,7 +2523,7 @@ def test_tool_injected_tool_call_id() -> None: "name": "foo", "id": "bar", } - ) == ToolMessage(0, tool_call_id="bar") # type: ignore[call-overload] + ) == ToolMessage("0", tool_call_id="bar") def test_tool_injected_tool_call_id_override_llm_generated() -> None: diff --git a/libs/langchain/langchain_classic/agents/output_parsers/json.py b/libs/langchain/langchain_classic/agents/output_parsers/json.py index 31a1ca9606d..fc8b4c350d3 100644 --- a/libs/langchain/langchain_classic/agents/output_parsers/json.py +++ b/libs/langchain/langchain_classic/agents/output_parsers/json.py @@ -38,9 +38,9 @@ class JSONAgentOutputParser(AgentOutputParser): def parse(self, text: str) -> AgentAction | AgentFinish: try: response = parse_json_markdown(text) - if isinstance(response, list): # type: ignore[unreachable] + if isinstance(response, list): # gpt turbo frequently ignores the directive to emit a single action - logger.warning("Got multiple action responses: %s", response) # type: ignore[unreachable] + logger.warning("Got multiple action responses: %s", response) response = response[0] if response["action"] == "Final Answer": return AgentFinish({"output": response["action_input"]}, text)