From efdfa00d1094fa63308e9455b37a200ed3a3d025 Mon Sep 17 00:00:00 2001
From: Christophe Bornet
Date: Sun, 27 Jul 2025 00:32:34 +0200
Subject: [PATCH] chore(langchain): add ruff rules ARG (#32110)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg

Co-authored-by: Mason Daugherty
---
 libs/langchain/langchain/agents/agent.py | 31 +++++++----------
 libs/langchain/langchain/agents/chat/base.py | 2 ++
 .../langchain/agents/conversational/base.py | 2 ++
 .../agents/conversational_chat/base.py | 2 ++
 libs/langchain/langchain/agents/mrkl/base.py | 2 ++
 .../agent_token_buffer_memory.py | 2 ++
 libs/langchain/langchain/agents/react/base.py | 4 +++
 .../agents/self_ask_with_search/base.py | 3 ++
 .../langchain/agents/structured_chat/base.py | 1 +
 libs/langchain/langchain/agents/tools.py | 3 ++
 .../callbacks/streaming_stdout_final_only.py | 3 ++
 libs/langchain/langchain/chains/api/base.py | 2 +-
 .../chains/combine_documents/base.py | 2 +-
 .../chains/conversational_retrieval/base.py | 3 ++
 .../langchain/chains/llm_bash/__init__.py | 2 +-
 .../chains/llm_symbolic_math/__init__.py | 2 +-
 libs/langchain/langchain/chains/loading.py | 4 +--
 libs/langchain/langchain/chains/moderation.py | 2 ++
 .../langchain/chains/qa_with_sources/base.py | 3 ++
 .../chains/qa_with_sources/vector_db.py | 2 ++
 .../chains/query_constructor/parser.py | 2 +-
 .../langchain/chains/retrieval_qa/base.py | 2 ++
 .../chains/router/embedding_router.py | 3 ++
 libs/langchain/langchain/chains/transform.py | 3 ++
 .../evaluation/embedding_distance/base.py | 8 +++++
 .../langchain/evaluation/exact_match/base.py | 3 ++
 .../evaluation/parsing/json_schema.py | 5 +--
 .../langchain/evaluation/regex_match/base.py | 3 ++
 .../evaluation/string_distance/base.py | 6 ++++
 libs/langchain/langchain/memory/buffer.py | 3 ++
 .../langchain/memory/buffer_window.py | 2 ++
 libs/langchain/langchain/memory/summary.py | 2 ++
 .../langchain/memory/summary_buffer.py | 3 ++
 .../langchain/memory/token_buffer.py | 2 ++
 .../document_compressors/chain_extract.py | 2 +-
 .../document_compressors/cohere_rerank.py | 2 ++
 .../cross_encoder_rerank.py | 2 ++
 .../document_compressors/embeddings_filter.py | 3 ++
 .../langchain/retrievers/multi_query.py | 2 +-
 .../langchain/retrievers/multi_vector.py | 3 ++
 .../retrievers/time_weighted_retriever.py | 3 ++
 .../langchain/smith/evaluation/progress.py | 1 -
 .../smith/evaluation/string_run_evaluator.py | 2 ++
 .../langchain/tools/python/__init__.py | 2 +-
 libs/langchain/pyproject.toml | 1 +
 .../cache/fake_embeddings.py | 2 ++
 .../tests/unit_tests/agents/test_agent.py | 10 +++---
 .../unit_tests/agents/test_agent_async.py | 2 ++
 .../unit_tests/agents/test_agent_iterator.py | 2 +-
 .../unit_tests/agents/test_initialize.py | 2 +-
 .../tests/unit_tests/agents/test_mrkl.py | 6 ++--
 .../agents/test_openai_assistant.py | 2 +-
 .../callbacks/fake_callback_handler.py | 33 +++++++++++++++++++
 .../tests/unit_tests/callbacks/test_file.py | 2 ++
 .../tests/unit_tests/callbacks/test_stdout.py | 2 ++
 .../tests/unit_tests/chains/test_base.py | 3 ++
 .../chains/test_combine_documents.py | 2 +-
 .../unit_tests/chains/test_conversation.py | 2 ++
 .../tests/unit_tests/chains/test_hyde.py | 5 +++
 .../unit_tests/chains/test_sequential.py | 3 ++
 .../unit_tests/document_loaders/test_base.py | 2 ++
 .../unit_tests/embeddings/test_caching.py | 3 ++
 .../evaluation/agents/test_eval_chain.py | 2 ++
 .../tests/unit_tests/indexes/test_indexing.py | 5 +++
 .../tests/unit_tests/llms/fake_chat_model.py | 4 +++
 .../tests/unit_tests/llms/fake_llm.py | 2 ++
.../unit_tests/llms/test_fake_chat_model.py | 2 ++ .../unit_tests/output_parsers/test_fix.py | 4 ++- .../unit_tests/output_parsers/test_retry.py | 2 ++ .../retrievers/self_query/test_base.py | 2 ++ .../retrievers/sequential_retriever.py | 11 +++++-- .../unit_tests/retrievers/test_ensemble.py | 2 ++ .../retrievers/test_multi_vector.py | 3 ++ .../retrievers/test_parent_document.py | 3 ++ .../test_time_weighted_retriever.py | 5 +++ .../tests/unit_tests/runnables/test_hub.py | 2 +- .../runnables/test_openai_functions.py | 2 ++ .../smith/evaluation/test_runner_utils.py | 18 ++++------ .../tests/unit_tests/tools/test_render.py | 4 +-- 79 files changed, 241 insertions(+), 62 deletions(-) diff --git a/libs/langchain/langchain/agents/agent.py b/libs/langchain/langchain/agents/agent.py index dab08813249..b2a24b9ebda 100644 --- a/libs/langchain/langchain/agents/agent.py +++ b/libs/langchain/langchain/agents/agent.py @@ -116,8 +116,8 @@ class BaseSingleActionAgent(BaseModel): def return_stopped_response( self, early_stopping_method: str, - intermediate_steps: list[tuple[AgentAction, str]], - **kwargs: Any, + intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002 + **_: Any, ) -> AgentFinish: """Return response when agent has been stopped due to max iterations. @@ -125,7 +125,6 @@ class BaseSingleActionAgent(BaseModel): early_stopping_method: Method to use for early stopping. intermediate_steps: Steps the LLM has taken to date, along with observations. - **kwargs: User inputs. Returns: AgentFinish: Agent finish object. @@ -168,6 +167,7 @@ class BaseSingleActionAgent(BaseModel): """Return Identifier of an agent type.""" raise NotImplementedError + @override def dict(self, **kwargs: Any) -> builtins.dict: """Return dictionary representation of agent. @@ -289,8 +289,8 @@ class BaseMultiActionAgent(BaseModel): def return_stopped_response( self, early_stopping_method: str, - intermediate_steps: list[tuple[AgentAction, str]], - **kwargs: Any, + intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002 + **_: Any, ) -> AgentFinish: """Return response when agent has been stopped due to max iterations. @@ -298,7 +298,6 @@ class BaseMultiActionAgent(BaseModel): early_stopping_method: Method to use for early stopping. intermediate_steps: Steps the LLM has taken to date, along with observations. - **kwargs: User inputs. Returns: AgentFinish: Agent finish object. @@ -317,6 +316,7 @@ class BaseMultiActionAgent(BaseModel): """Return Identifier of an agent type.""" raise NotImplementedError + @override def dict(self, **kwargs: Any) -> builtins.dict: """Return dictionary representation of agent.""" _dict = super().model_dump() @@ -651,6 +651,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent): """ return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) + @override def dict(self, **kwargs: Any) -> builtins.dict: """Return dictionary representation of agent.""" _dict = super().dict() @@ -735,6 +736,7 @@ class Agent(BaseSingleActionAgent): allowed_tools: Optional[list[str]] = None """Allowed tools for the agent. If None, all tools are allowed.""" + @override def dict(self, **kwargs: Any) -> builtins.dict: """Return dictionary representation of agent.""" _dict = super().dict() @@ -750,18 +752,6 @@ class Agent(BaseSingleActionAgent): """Return values of the agent.""" return ["output"] - def _fix_text(self, text: str) -> str: - """Fix the text. - - Args: - text: Text to fix. - - Returns: - str: Fixed text. - """ - msg = "fix_text not implemented for this agent." 
- raise ValueError(msg) - @property def _stop(self) -> list[str]: return [ @@ -1021,6 +1011,7 @@ class ExceptionTool(BaseTool): description: str = "Exception tool" """Description of the tool.""" + @override def _run( self, query: str, @@ -1028,6 +1019,7 @@ class ExceptionTool(BaseTool): ) -> str: return query + @override async def _arun( self, query: str, @@ -1188,6 +1180,7 @@ class AgentExecutor(Chain): return cast("RunnableAgentType", self.agent) return self.agent + @override def save(self, file_path: Union[Path, str]) -> None: """Raise error - saving not supported for Agent Executors. @@ -1218,7 +1211,7 @@ class AgentExecutor(Chain): callbacks: Callbacks = None, *, include_run_info: bool = False, - async_: bool = False, # arg kept for backwards compat, but ignored + async_: bool = False, # noqa: ARG002 arg kept for backwards compat, but ignored ) -> AgentExecutorIterator: """Enables iteration over steps taken to reach final output. diff --git a/libs/langchain/langchain/agents/chat/base.py b/libs/langchain/langchain/agents/chat/base.py index 883238e2f24..a19fb899c9b 100644 --- a/libs/langchain/langchain/agents/chat/base.py +++ b/libs/langchain/langchain/agents/chat/base.py @@ -13,6 +13,7 @@ from langchain_core.prompts.chat import ( ) from langchain_core.tools import BaseTool from pydantic import Field +from typing_extensions import override from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain.agents.agent import Agent, AgentOutputParser @@ -65,6 +66,7 @@ class ChatAgent(Agent): return agent_scratchpad @classmethod + @override def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return ChatOutputParser() diff --git a/libs/langchain/langchain/agents/conversational/base.py b/libs/langchain/langchain/agents/conversational/base.py index 71aa243c55c..8bba469546d 100644 --- a/libs/langchain/langchain/agents/conversational/base.py +++ b/libs/langchain/langchain/agents/conversational/base.py @@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import PromptTemplate from langchain_core.tools import BaseTool from pydantic import Field +from typing_extensions import override from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain.agents.agent import Agent, AgentOutputParser @@ -35,6 +36,7 @@ class ConversationalAgent(Agent): """Output parser for the agent.""" @classmethod + @override def _get_default_output_parser( cls, ai_prefix: str = "AI", diff --git a/libs/langchain/langchain/agents/conversational_chat/base.py b/libs/langchain/langchain/agents/conversational_chat/base.py index e814cc53784..19b7d4a0fe9 100644 --- a/libs/langchain/langchain/agents/conversational_chat/base.py +++ b/libs/langchain/langchain/agents/conversational_chat/base.py @@ -20,6 +20,7 @@ from langchain_core.prompts.chat import ( ) from langchain_core.tools import BaseTool from pydantic import Field +from typing_extensions import override from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.conversational_chat.output_parser import ConvoOutputParser @@ -42,6 +43,7 @@ class ConversationalChatAgent(Agent): """Template for the tool response.""" @classmethod + @override def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return ConvoOutputParser() diff --git a/libs/langchain/langchain/agents/mrkl/base.py b/libs/langchain/langchain/agents/mrkl/base.py index 9bc04129e93..e26c149154c 100644 --- a/libs/langchain/langchain/agents/mrkl/base.py +++ 
b/libs/langchain/langchain/agents/mrkl/base.py @@ -12,6 +12,7 @@ from langchain_core.prompts import PromptTemplate from langchain_core.tools import BaseTool, Tool from langchain_core.tools.render import render_text_description from pydantic import Field +from typing_extensions import override from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser @@ -51,6 +52,7 @@ class ZeroShotAgent(Agent): output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser) @classmethod + @override def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return MRKLOutputParser() diff --git a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py index 284e0e72c93..cc403fd62cc 100644 --- a/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py +++ b/libs/langchain/langchain/agents/openai_functions_agent/agent_token_buffer_memory.py @@ -4,6 +4,7 @@ from typing import Any from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, get_buffer_string +from typing_extensions import override from langchain.agents.format_scratchpad import ( format_to_openai_function_messages, @@ -55,6 +56,7 @@ class AgentTokenBufferMemory(BaseChatMemory): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer. diff --git a/libs/langchain/langchain/agents/react/base.py b/libs/langchain/langchain/agents/react/base.py index 33c3099815c..b502e1f3463 100644 --- a/libs/langchain/langchain/agents/react/base.py +++ b/libs/langchain/langchain/agents/react/base.py @@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from langchain_core.tools import BaseTool, Tool from pydantic import Field +from typing_extensions import override from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser @@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent): output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser) @classmethod + @override def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return ReActOutputParser() @@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent): return AgentType.REACT_DOCSTORE @classmethod + @override def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: """Return default prompt.""" return WIKI_PROMPT @@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent): """Agent for the ReAct TextWorld chain.""" @classmethod + @override def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: """Return default prompt.""" return TEXTWORLD_PROMPT diff --git a/libs/langchain/langchain/agents/self_ask_with_search/base.py b/libs/langchain/langchain/agents/self_ask_with_search/base.py index 1078cb0e7a3..aa69f443f87 100644 --- a/libs/langchain/langchain/agents/self_ask_with_search/base.py +++ b/libs/langchain/langchain/agents/self_ask_with_search/base.py @@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.tools import BaseTool, Tool from pydantic import Field +from typing_extensions import override from langchain.agents.agent import 
Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent_types import AgentType @@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent): output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser) @classmethod + @override def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: return SelfAskOutputParser() @@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent): return AgentType.SELF_ASK_WITH_SEARCH @classmethod + @override def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: """Prompt does not depend on tools.""" return PROMPT diff --git a/libs/langchain/langchain/agents/structured_chat/base.py b/libs/langchain/langchain/agents/structured_chat/base.py index 320fd8273bc..a00828b7d20 100644 --- a/libs/langchain/langchain/agents/structured_chat/base.py +++ b/libs/langchain/langchain/agents/structured_chat/base.py @@ -71,6 +71,7 @@ class StructuredChatAgent(Agent): pass @classmethod + @override def _get_default_output_parser( cls, llm: Optional[BaseLanguageModel] = None, diff --git a/libs/langchain/langchain/agents/tools.py b/libs/langchain/langchain/agents/tools.py index 98738f5c929..ec50400faf6 100644 --- a/libs/langchain/langchain/agents/tools.py +++ b/libs/langchain/langchain/agents/tools.py @@ -7,6 +7,7 @@ from langchain_core.callbacks import ( CallbackManagerForToolRun, ) from langchain_core.tools import BaseTool, tool +from typing_extensions import override class InvalidTool(BaseTool): @@ -17,6 +18,7 @@ class InvalidTool(BaseTool): description: str = "Called when tool name is invalid. Suggests valid tool names." """Description of the tool.""" + @override def _run( self, requested_tool_name: str, @@ -30,6 +32,7 @@ class InvalidTool(BaseTool): f"try one of [{available_tool_names_str}]." ) + @override async def _arun( self, requested_tool_name: str, diff --git a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py index ddcfa1d4567..17b1fd050f5 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout_final_only.py @@ -4,6 +4,7 @@ import sys from typing import Any, Optional from langchain_core.callbacks import StreamingStdOutCallbackHandler +from typing_extensions import override DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] @@ -63,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): self.stream_prefix = stream_prefix self.answer_reached = False + @override def on_llm_start( self, serialized: dict[str, Any], @@ -72,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): """Run when LLM starts running.""" self.answer_reached = False + @override def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run on new LLM token. 
Only available when streaming is enabled.""" diff --git a/libs/langchain/langchain/chains/api/base.py b/libs/langchain/langchain/chains/api/base.py index 9a03f142ac6..6be2c29a096 100644 --- a/libs/langchain/langchain/chains/api/base.py +++ b/libs/langchain/langchain/chains/api/base.py @@ -388,7 +388,7 @@ except ImportError: class APIChain: # type: ignore[no-redef] """Raise an ImportError if APIChain is used without langchain_community.""" - def __init__(self, *args: Any, **kwargs: Any) -> None: + def __init__(self, *_: Any, **__: Any) -> None: """Raise an ImportError if APIChain is used without langchain_community.""" msg = ( "To use the APIChain, you must install the langchain_community package." diff --git a/libs/langchain/langchain/chains/combine_documents/base.py b/libs/langchain/langchain/chains/combine_documents/base.py index d0fc8ca77c9..74fd6135f53 100644 --- a/libs/langchain/langchain/chains/combine_documents/base.py +++ b/libs/langchain/langchain/chains/combine_documents/base.py @@ -83,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC): """ return [self.output_key] - def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: + def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002 """Return the prompt length given the documents passed in. This can be used by a caller to determine whether passing in a list diff --git a/libs/langchain/langchain/chains/conversational_retrieval/base.py b/libs/langchain/langchain/chains/conversational_retrieval/base.py index ba01acf7caf..3b4d417bb0c 100644 --- a/libs/langchain/langchain/chains/conversational_retrieval/base.py +++ b/libs/langchain/langchain/chains/conversational_retrieval/base.py @@ -402,6 +402,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): return docs[:num_docs] + @override def _get_docs( self, question: str, @@ -416,6 +417,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain): ) return self._reduce_tokens_below_limit(docs) + @override async def _aget_docs( self, question: str, @@ -512,6 +514,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain): ) return values + @override def _get_docs( self, question: str, diff --git a/libs/langchain/langchain/chains/llm_bash/__init__.py b/libs/langchain/langchain/chains/llm_bash/__init__.py index ea0494dc2cc..4922b1595b5 100644 --- a/libs/langchain/langchain/chains/llm_bash/__init__.py +++ b/libs/langchain/langchain/chains/llm_bash/__init__.py @@ -1,4 +1,4 @@ -def __getattr__(name: str = "") -> None: +def __getattr__(_: str = "") -> None: """Raise an error on import since is deprecated.""" msg = ( "This module has been moved to langchain-experimental. " diff --git a/libs/langchain/langchain/chains/llm_symbolic_math/__init__.py b/libs/langchain/langchain/chains/llm_symbolic_math/__init__.py index ae8f59f6e60..52de5444acb 100644 --- a/libs/langchain/langchain/chains/llm_symbolic_math/__init__.py +++ b/libs/langchain/langchain/chains/llm_symbolic_math/__init__.py @@ -1,4 +1,4 @@ -def __getattr__(name: str = "") -> None: +def __getattr__(_: str = "") -> None: """Raise an error on import since is deprecated.""" msg = ( "This module has been moved to langchain-experimental. 
" diff --git a/libs/langchain/langchain/chains/loading.py b/libs/langchain/langchain/chains/loading.py index a24365b808b..174f8e3a504 100644 --- a/libs/langchain/langchain/chains/loading.py +++ b/libs/langchain/langchain/chains/loading.py @@ -39,7 +39,7 @@ try: from langchain_community.llms.loading import load_llm, load_llm_from_config except ImportError: - def load_llm(*args: Any, **kwargs: Any) -> None: + def load_llm(*_: Any, **__: Any) -> None: """Import error for load_llm.""" msg = ( "To use this load_llm functionality you must install the " @@ -48,7 +48,7 @@ except ImportError: ) raise ImportError(msg) - def load_llm_from_config(*args: Any, **kwargs: Any) -> None: + def load_llm_from_config(*_: Any, **__: Any) -> None: """Import error for load_llm_from_config.""" msg = ( "To use this load_llm_from_config functionality you must install the " diff --git a/libs/langchain/langchain/chains/moderation.py b/libs/langchain/langchain/chains/moderation.py index e7a81977836..2b687d9e9bd 100644 --- a/libs/langchain/langchain/chains/moderation.py +++ b/libs/langchain/langchain/chains/moderation.py @@ -8,6 +8,7 @@ from langchain_core.callbacks import ( ) from langchain_core.utils import check_package_version, get_from_dict_or_env from pydantic import Field, model_validator +from typing_extensions import override from langchain.chains.base import Chain @@ -105,6 +106,7 @@ class OpenAIModerationChain(Chain): return error_str return text + @override def _call( self, inputs: dict[str, Any], diff --git a/libs/langchain/langchain/chains/qa_with_sources/base.py b/libs/langchain/langchain/chains/qa_with_sources/base.py index d1e115870cc..47396409319 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/base.py +++ b/libs/langchain/langchain/chains/qa_with_sources/base.py @@ -16,6 +16,7 @@ from langchain_core.documents import Document from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts import BasePromptTemplate from pydantic import ConfigDict, model_validator +from typing_extensions import override from langchain.chains import ReduceDocumentsChain from langchain.chains.base import Chain @@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain): """ return [self.input_docs_key, self.question_key] + @override def _get_docs( self, inputs: dict[str, Any], @@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain): """Get docs to run questioning over.""" return inputs.pop(self.input_docs_key) + @override async def _aget_docs( self, inputs: dict[str, Any], diff --git a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py index e758bc87b12..7f821cc3cbe 100644 --- a/libs/langchain/langchain/chains/qa_with_sources/vector_db.py +++ b/libs/langchain/langchain/chains/qa_with_sources/vector_db.py @@ -10,6 +10,7 @@ from langchain_core.callbacks import ( from langchain_core.documents import Document from langchain_core.vectorstores import VectorStore from pydantic import Field, model_validator +from typing_extensions import override from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain @@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain): return docs[:num_docs] + @override def _get_docs( self, inputs: dict[str, Any], diff --git a/libs/langchain/langchain/chains/query_constructor/parser.py b/libs/langchain/langchain/chains/query_constructor/parser.py index 
e03a26efe13..97bb312d680 100644 --- a/libs/langchain/langchain/chains/query_constructor/parser.py +++ b/libs/langchain/langchain/chains/query_constructor/parser.py @@ -11,7 +11,7 @@ try: from lark import Lark, Transformer, v_args except ImportError: - def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc] + def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc] """Dummy decorator for when lark is not installed.""" return lambda _: None diff --git a/libs/langchain/langchain/chains/retrieval_qa/base.py b/libs/langchain/langchain/chains/retrieval_qa/base.py index 3f9e0630dae..4c6cf5de307 100644 --- a/libs/langchain/langchain/chains/retrieval_qa/base.py +++ b/libs/langchain/langchain/chains/retrieval_qa/base.py @@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore from pydantic import ConfigDict, Field, model_validator +from typing_extensions import override from langchain.chains.base import Chain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain @@ -330,6 +331,7 @@ class VectorDBQA(BaseRetrievalQA): raise ValueError(msg) return values + @override def _get_docs( self, question: str, diff --git a/libs/langchain/langchain/chains/router/embedding_router.py b/libs/langchain/langchain/chains/router/embedding_router.py index 4e9cc94e012..ca2e8218bfa 100644 --- a/libs/langchain/langchain/chains/router/embedding_router.py +++ b/libs/langchain/langchain/chains/router/embedding_router.py @@ -11,6 +11,7 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore from pydantic import ConfigDict +from typing_extensions import override from langchain.chains.router.base import RouterChain @@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain): """ return self.routing_keys + @override def _call( self, inputs: dict[str, Any], @@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain): results = self.vectorstore.similarity_search(_input, k=1) return {"next_inputs": inputs, "destination": results[0].metadata["name"]} + @override async def _acall( self, inputs: dict[str, Any], diff --git a/libs/langchain/langchain/chains/transform.py b/libs/langchain/langchain/chains/transform.py index a241b8dc165..4f1ac06733c 100644 --- a/libs/langchain/langchain/chains/transform.py +++ b/libs/langchain/langchain/chains/transform.py @@ -10,6 +10,7 @@ from langchain_core.callbacks import ( CallbackManagerForChainRun, ) from pydantic import Field +from typing_extensions import override from langchain.chains.base import Chain @@ -63,6 +64,7 @@ class TransformChain(Chain): """ return self.output_variables + @override def _call( self, inputs: dict[str, str], @@ -70,6 +72,7 @@ class TransformChain(Chain): ) -> dict[str, str]: return self.transform_cb(inputs) + @override async def _acall( self, inputs: dict[str, Any], diff --git a/libs/langchain/langchain/evaluation/embedding_distance/base.py b/libs/langchain/langchain/evaluation/embedding_distance/base.py index 53a67962a54..bf65305114b 100644 --- a/libs/langchain/langchain/evaluation/embedding_distance/base.py +++ b/libs/langchain/langchain/evaluation/embedding_distance/base.py @@ -331,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): """ return ["prediction", "reference"] + @override def _call( self, inputs: dict[str, Any], @@ -355,6 +356,7 @@ class 
EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): score = self._compute_score(vectors) return {"score": score} + @override async def _acall( self, inputs: dict[str, Any], @@ -382,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): score = self._compute_score(vectors) return {"score": score} + @override def _evaluate_strings( self, *, @@ -416,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator): ) return self._prepare_output(result) + @override async def _aevaluate_strings( self, *, @@ -478,6 +482,7 @@ class PairwiseEmbeddingDistanceEvalChain( """Return the evaluation name.""" return f"pairwise_embedding_{self.distance_metric.value}_distance" + @override def _call( self, inputs: dict[str, Any], @@ -505,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain( score = self._compute_score(vectors) return {"score": score} + @override async def _acall( self, inputs: dict[str, Any], @@ -532,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain( score = self._compute_score(vectors) return {"score": score} + @override def _evaluate_string_pairs( self, *, @@ -567,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain( ) return self._prepare_output(result) + @override async def _aevaluate_string_pairs( self, *, diff --git a/libs/langchain/langchain/evaluation/exact_match/base.py b/libs/langchain/langchain/evaluation/exact_match/base.py index 4ee092b64f0..c7ae0dc1207 100644 --- a/libs/langchain/langchain/evaluation/exact_match/base.py +++ b/libs/langchain/langchain/evaluation/exact_match/base.py @@ -1,6 +1,8 @@ import string from typing import Any +from typing_extensions import override + from langchain.evaluation.schema import StringEvaluator @@ -78,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator): """ return "exact_match" + @override def _evaluate_strings( # type: ignore[override] self, *, diff --git a/libs/langchain/langchain/evaluation/parsing/json_schema.py b/libs/langchain/langchain/evaluation/parsing/json_schema.py index 67840078228..0adbd6140ab 100644 --- a/libs/langchain/langchain/evaluation/parsing/json_schema.py +++ b/libs/langchain/langchain/evaluation/parsing/json_schema.py @@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator): """ # noqa: E501 - def __init__(self, **kwargs: Any) -> None: + def __init__(self, **_: Any) -> None: """Initializes the JsonSchemaEvaluator. - Args: - kwargs: Additional keyword arguments. - Raises: ImportError: If the jsonschema package is not installed. 
""" diff --git a/libs/langchain/langchain/evaluation/regex_match/base.py b/libs/langchain/langchain/evaluation/regex_match/base.py index d86eb45a233..6ad35f192b0 100644 --- a/libs/langchain/langchain/evaluation/regex_match/base.py +++ b/libs/langchain/langchain/evaluation/regex_match/base.py @@ -1,6 +1,8 @@ import re from typing import Any +from typing_extensions import override + from langchain.evaluation.schema import StringEvaluator @@ -70,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator): """ return "regex_match" + @override def _evaluate_strings( # type: ignore[override] self, *, diff --git a/libs/langchain/langchain/evaluation/string_distance/base.py b/libs/langchain/langchain/evaluation/string_distance/base.py index d4fc2b5d60c..6a62452b9fb 100644 --- a/libs/langchain/langchain/evaluation/string_distance/base.py +++ b/libs/langchain/langchain/evaluation/string_distance/base.py @@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin): """ return f"{self.distance.value}_distance" + @override def _call( self, inputs: dict[str, Any], @@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin): """ return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])} + @override async def _acall( self, inputs: dict[str, Any], @@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi """ return f"pairwise_{self.distance.value}_distance" + @override def _call( self, inputs: dict[str, Any], @@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), } + @override async def _acall( self, inputs: dict[str, Any], @@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), } + @override def _evaluate_string_pairs( self, *, @@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi ) return self._prepare_output(result) + @override async def _aevaluate_string_pairs( self, *, diff --git a/libs/langchain/langchain/memory/buffer.py b/libs/langchain/langchain/memory/buffer.py index 9beffa87d9d..57020000e81 100644 --- a/libs/langchain/langchain/memory/buffer.py +++ b/libs/langchain/langchain/memory/buffer.py @@ -79,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer.""" return {self.memory_key: self.buffer} + @override async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return key-value pairs given the text input to the chain.""" buffer = await self.abuffer() @@ -133,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]: """Return history buffer.""" return {self.memory_key: self.buffer} diff --git a/libs/langchain/langchain/memory/buffer_window.py b/libs/langchain/langchain/memory/buffer_window.py index 8e586bdbc84..32bbb699cb1 100644 --- a/libs/langchain/langchain/memory/buffer_window.py +++ b/libs/langchain/langchain/memory/buffer_window.py @@ -2,6 +2,7 @@ from typing import Any, Union from langchain_core._api import deprecated from langchain_core.messages import BaseMessage, get_buffer_string +from 
typing_extensions import override from langchain.memory.chat_memory import BaseChatMemory @@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer.""" return {self.memory_key: self.buffer} diff --git a/libs/langchain/langchain/memory/summary.py b/libs/langchain/langchain/memory/summary.py index 31dc4785161..d5518a6e9e1 100644 --- a/libs/langchain/langchain/memory/summary.py +++ b/libs/langchain/langchain/memory/summary.py @@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin from langchain_core.prompts import BasePromptTemplate from langchain_core.utils import pre_init from pydantic import BaseModel +from typing_extensions import override from langchain.chains.llm import LLMChain from langchain.memory.chat_memory import BaseChatMemory @@ -133,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer.""" if self.return_messages: diff --git a/libs/langchain/langchain/memory/summary_buffer.py b/libs/langchain/langchain/memory/summary_buffer.py index 692ea0d81eb..7b82702240d 100644 --- a/libs/langchain/langchain/memory/summary_buffer.py +++ b/libs/langchain/langchain/memory/summary_buffer.py @@ -3,6 +3,7 @@ from typing import Any, Union from langchain_core._api import deprecated from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.utils import pre_init +from typing_extensions import override from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.summary import SummarizerMixin @@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer.""" buffer = self.chat_memory.messages @@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin): ) return {self.memory_key: final_buffer} + @override async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Asynchronously return key-value pairs given the text input to the chain.""" buffer = await self.chat_memory.aget_messages() diff --git a/libs/langchain/langchain/memory/token_buffer.py b/libs/langchain/langchain/memory/token_buffer.py index 527ac7eba6e..f100b7e89a7 100644 --- a/libs/langchain/langchain/memory/token_buffer.py +++ b/libs/langchain/langchain/memory/token_buffer.py @@ -3,6 +3,7 @@ from typing import Any from langchain_core._api import deprecated from langchain_core.language_models import BaseLanguageModel from langchain_core.messages import BaseMessage, get_buffer_string +from typing_extensions import override from langchain.memory.chat_memory import BaseChatMemory @@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory): """ return [self.memory_key] + @override def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: """Return history buffer.""" return {self.memory_key: self.buffer} diff --git a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py index c769d88bbaf..8b7f75bc28a 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py +++ 
b/libs/langchain/langchain/retrievers/document_compressors/chain_extract.py @@ -110,7 +110,7 @@ class LLMChainExtractor(BaseDocumentCompressor): llm: BaseLanguageModel, prompt: Optional[PromptTemplate] = None, get_input: Optional[Callable[[str, Document], str]] = None, - llm_chain_kwargs: Optional[dict] = None, + llm_chain_kwargs: Optional[dict] = None, # noqa: ARG003 ) -> LLMChainExtractor: """Initialize from LLM.""" _prompt = prompt if prompt is not None else _get_default_chain_prompt() diff --git a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py index ed9e274bb10..4d779c8e3cf 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cohere_rerank.py @@ -9,6 +9,7 @@ from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.utils import get_from_dict_or_env from pydantic import ConfigDict, model_validator +from typing_extensions import override @deprecated( @@ -98,6 +99,7 @@ class CohereRerank(BaseDocumentCompressor): for res in results ] + @override def compress_documents( self, documents: Sequence[Document], diff --git a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py index f786279eaf0..553811eb94c 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py +++ b/libs/langchain/langchain/retrievers/document_compressors/cross_encoder_rerank.py @@ -7,6 +7,7 @@ from typing import Optional from langchain_core.callbacks import Callbacks from langchain_core.documents import BaseDocumentCompressor, Document from pydantic import ConfigDict +from typing_extensions import override from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder @@ -25,6 +26,7 @@ class CrossEncoderReranker(BaseDocumentCompressor): extra="forbid", ) + @override def compress_documents( self, documents: Sequence[Document], diff --git a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py index 9c617344c5d..cc4c3098ef3 100644 --- a/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/libs/langchain/langchain/retrievers/document_compressors/embeddings_filter.py @@ -6,6 +6,7 @@ from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.embeddings import Embeddings from langchain_core.utils import pre_init from pydantic import ConfigDict, Field +from typing_extensions import override def _get_similarity_function() -> Callable: @@ -50,6 +51,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): raise ValueError(msg) return values + @override def compress_documents( self, documents: Sequence[Document], @@ -93,6 +95,7 @@ class EmbeddingsFilter(BaseDocumentCompressor): stateful_documents[i].state["query_similarity_score"] = similarity[i] return [stateful_documents[i] for i in included_idxs] + @override async def acompress_documents( self, documents: Sequence[Document], diff --git a/libs/langchain/langchain/retrievers/multi_query.py b/libs/langchain/langchain/retrievers/multi_query.py index cb7f734348b..4006684aa5d 100644 --- a/libs/langchain/langchain/retrievers/multi_query.py +++ 
b/libs/langchain/langchain/retrievers/multi_query.py @@ -67,7 +67,7 @@ class MultiQueryRetriever(BaseRetriever): retriever: BaseRetriever, llm: BaseLanguageModel, prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT, - parser_key: Optional[str] = None, + parser_key: Optional[str] = None, # noqa: ARG003 include_original: bool = False, # noqa: FBT001,FBT002 ) -> "MultiQueryRetriever": """Initialize from llm using default template. diff --git a/libs/langchain/langchain/retrievers/multi_vector.py b/libs/langchain/langchain/retrievers/multi_vector.py index 94e80cbe4d8..9e973e8190e 100644 --- a/libs/langchain/langchain/retrievers/multi_vector.py +++ b/libs/langchain/langchain/retrievers/multi_vector.py @@ -10,6 +10,7 @@ from langchain_core.retrievers import BaseRetriever from langchain_core.stores import BaseStore, ByteStore from langchain_core.vectorstores import VectorStore from pydantic import Field, model_validator +from typing_extensions import override from langchain.storage._lc_store import create_kv_docstore @@ -54,6 +55,7 @@ class MultiVectorRetriever(BaseRetriever): values["docstore"] = docstore return values + @override def _get_relevant_documents( self, query: str, @@ -91,6 +93,7 @@ class MultiVectorRetriever(BaseRetriever): docs = self.docstore.mget(ids) return [d for d in docs if d is not None] + @override async def _aget_relevant_documents( self, query: str, diff --git a/libs/langchain/langchain/retrievers/time_weighted_retriever.py b/libs/langchain/langchain/retrievers/time_weighted_retriever.py index 4bee5cedfd5..4087cce12fb 100644 --- a/libs/langchain/langchain/retrievers/time_weighted_retriever.py +++ b/libs/langchain/langchain/retrievers/time_weighted_retriever.py @@ -10,6 +10,7 @@ from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore from pydantic import ConfigDict, Field +from typing_extensions import override def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: @@ -128,6 +129,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): result.append(buffered_doc) return result + @override def _get_relevant_documents( self, query: str, @@ -142,6 +144,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever): docs_and_scores.update(self.get_salient_docs(query)) return self._get_rescored_docs(docs_and_scores) + @override async def _aget_relevant_documents( self, query: str, diff --git a/libs/langchain/langchain/smith/evaluation/progress.py b/libs/langchain/langchain/smith/evaluation/progress.py index 613a4f63073..4282f9c76ae 100644 --- a/libs/langchain/langchain/smith/evaluation/progress.py +++ b/libs/langchain/langchain/smith/evaluation/progress.py @@ -19,7 +19,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler): total: int, ncols: int = 50, end_with: str = "\n", - **kwargs: Any, ): """Initialize the progress bar. 
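A recurring pattern in the library hunks above (retrievers, memory, chains, evaluation): instead of silencing the new lint, methods that implement a base-class hook (`_call`, `_get_relevant_documents`, `load_memory_variables`, `compress_documents`, ...) gain the `@override` decorator from `typing_extensions`. Such methods must keep the inherited signature even when they ignore some of its parameters, and, as the hunks show, methods marked this way need no `# noqa: ARG002`. A minimal sketch of the pattern, with invented class names (none of this code is from the patch):

    from typing_extensions import override


    class BaseHandler:
        """Stand-in for a framework base class with a fixed hook signature."""

        def handle(self, query: str, verbose: bool = False) -> str:
            raise NotImplementedError


    class EchoHandler(BaseHandler):
        @override
        def handle(self, query: str, verbose: bool = False) -> str:
            # `verbose` is required by the inherited signature but unused here;
            # @override documents that constraint, and a type checker will flag
            # this method if BaseHandler.handle is renamed or its signature changes.
            return query


    print(EchoHandler().handle("ping"))

The trade-off versus a `# noqa` comment is that the decorator records why the unused parameter must stay (the signature is inherited) instead of merely suppressing the finding.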
diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py index b481c2a3fb0..0ec44e2d2f4 100644 --- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py +++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py @@ -355,6 +355,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator): feedback.evaluator_info[RUN_KEY] = output[RUN_KEY] return feedback + @override def evaluate_run( self, run: Run, @@ -372,6 +373,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator): # TODO: Add run ID once we can declare it via callbacks ) + @override async def aevaluate_run( self, run: Run, diff --git a/libs/langchain/langchain/tools/python/__init__.py b/libs/langchain/langchain/tools/python/__init__.py index c92bc153fde..2a199c1b741 100644 --- a/libs/langchain/langchain/tools/python/__init__.py +++ b/libs/langchain/langchain/tools/python/__init__.py @@ -1,7 +1,7 @@ from typing import Any -def __getattr__(name: str = "") -> Any: +def __getattr__(_: str = "") -> Any: msg = ( "This tool has been moved to langchain experiment. " "This tool has access to a python REPL. " diff --git a/libs/langchain/pyproject.toml b/libs/langchain/pyproject.toml index c48fe751533..1ec57de2480 100644 --- a/libs/langchain/pyproject.toml +++ b/libs/langchain/pyproject.toml @@ -145,6 +145,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy [tool.ruff.lint] select = [ "A", # flake8-builtins + "ARG", # flake8-unused-arguments "ASYNC", # flake8-async "B", # flake8-bugbear "C4", # flake8-comprehensions diff --git a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py index 9f318b60cfe..f427e3c478e 100644 --- a/libs/langchain/tests/integration_tests/cache/fake_embeddings.py +++ b/libs/langchain/tests/integration_tests/cache/fake_embeddings.py @@ -3,6 +3,7 @@ import math from langchain_core.embeddings import Embeddings +from typing_extensions import override fake_texts = ["foo", "bar", "baz"] @@ -18,6 +19,7 @@ class FakeEmbeddings(Embeddings): async def aembed_documents(self, texts: list[str]) -> list[list[float]]: return self.embed_documents(texts) + @override def embed_query(self, text: str) -> list[float]: """Return constant query embeddings. Embeddings are identical to embed_documents(texts)[0]. 
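The `pyproject.toml` hunk above is the switch that turns all of this on: adding `"ARG"` to the `[tool.ruff.lint]` select list enables the whole flake8-unused-arguments group (ARG001 for functions, ARG002 for methods, ARG003 for classmethods, ARG004 for staticmethods, ARG005 for lambdas). The rest of the patch then clears each finding in one of two ways: rename a deliberately ignored parameter to an underscore name (`_`, `__`, `_x`), or keep the public name and add a targeted `# noqa: ARG00x` where the argument is retained only for backwards compatibility. A small sketch of both fixes, with invented names (not code from the patch):

    from typing import Any


    # Before the rule is enabled: ruff would report ARG001, `payload` is never used.
    def on_event(event: str, payload: dict[str, Any]) -> str:
        return event


    # Fix 1: rename the parameter to an underscore name; underscore-prefixed
    # names are treated as intentionally unused and are not reported.
    def on_event_renamed(event: str, _: dict[str, Any]) -> str:
        return event


    class LegacyHandler:
        # Fix 2: the name must stay (callers pass it by keyword), so the
        # finding is suppressed for this one line only.
        def on_event(self, event: str, payload: dict[str, Any]) -> str:  # noqa: ARG002
            return event


    print(on_event_renamed("ready", {}))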
diff --git a/libs/langchain/tests/unit_tests/agents/test_agent.py b/libs/langchain/tests/unit_tests/agents/test_agent.py index 6442e8e3dd8..ea6c455db69 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent.py @@ -25,6 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables.utils import add from langchain_core.tools import Tool, tool from langchain_core.tracers import RunLog, RunLogPatch +from typing_extensions import override from langchain.agents import ( AgentExecutor, @@ -48,6 +49,7 @@ class FakeListLLM(LLM): responses: list[str] i: int = -1 + @override def _call( self, prompt: str, @@ -462,7 +464,7 @@ async def test_runnable_agent() -> None: ], ) - def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: + def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]: """A parser.""" return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message") @@ -569,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None: ], ) - def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: + def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]: """A parser.""" return cast("Union[AgentFinish, AgentAction]", next(parser_responses)) @@ -681,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None: ], ) - def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: + def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]: """A parser.""" return cast("Union[AgentFinish, AgentAction]", next(parser_responses)) @@ -1032,7 +1034,7 @@ async def test_openai_agent_tools_agent() -> None: ], ) - GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc] + GenericFakeChatModel.bind_tools = lambda self, _: self # type: ignore[assignment,misc] model = GenericFakeChatModel(messages=infinite_cycle) @tool diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_async.py b/libs/langchain/tests/unit_tests/agents/test_agent_async.py index 9716ab270a6..e7cba51a3b1 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent_async.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent_async.py @@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM from langchain_core.messages import AIMessage, HumanMessage from langchain_core.runnables.utils import add from langchain_core.tools import Tool +from typing_extensions import override from langchain.agents import AgentExecutor, AgentType, initialize_agent from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler @@ -19,6 +20,7 @@ class FakeListLLM(LLM): responses: list[str] i: int = -1 + @override def _call( self, prompt: str, diff --git a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py index 95e466701af..b061604312d 100644 --- a/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py +++ b/libs/langchain/tests/unit_tests/agents/test_agent_iterator.py @@ -364,7 +364,7 @@ def test_agent_iterator_failing_tool() -> None: tools = [ Tool( name="FailingTool", - func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError + func=lambda _: 1 / 0, # This tool will raise a ZeroDivisionError description="A tool that fails", ), ] diff --git a/libs/langchain/tests/unit_tests/agents/test_initialize.py b/libs/langchain/tests/unit_tests/agents/test_initialize.py index 39473af4ad0..f6ec7ead06c 100644 --- 
a/libs/langchain/tests/unit_tests/agents/test_initialize.py +++ b/libs/langchain/tests/unit_tests/agents/test_initialize.py @@ -8,7 +8,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM @tool -def my_tool(query: str) -> str: +def my_tool(query: str) -> str: # noqa: ARG001 """A fake tool.""" return "fake tool" diff --git a/libs/langchain/tests/unit_tests/agents/test_mrkl.py b/libs/langchain/tests/unit_tests/agents/test_mrkl.py index 3bb5678d5c1..b1565925648 100644 --- a/libs/langchain/tests/unit_tests/agents/test_mrkl.py +++ b/libs/langchain/tests/unit_tests/agents/test_mrkl.py @@ -141,11 +141,11 @@ def test_valid_action_and_answer_raises_exception() -> None: def test_from_chains() -> None: """Test initializing from chains.""" chain_configs = [ - Tool(name="foo", func=lambda x: "foo", description="foobar1"), - Tool(name="bar", func=lambda x: "bar", description="foobar2"), + Tool(name="foo", func=lambda _x: "foo", description="foobar1"), + Tool(name="bar", func=lambda _x: "bar", description="foobar2"), ] agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs) - expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2" + expected_tools_prompt = "foo(_x) - foobar1\nbar(_x) - foobar2" expected_tool_names = "foo, bar" expected_template = "\n\n".join( [ diff --git a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py index f29674672e8..8304295ac91 100644 --- a/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py +++ b/libs/langchain/tests/unit_tests/agents/test_openai_assistant.py @@ -7,7 +7,7 @@ import pytest from langchain.agents.openai_assistant import OpenAIAssistantRunnable -def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any: +def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any: client = AsyncMock() if use_async else MagicMock() mock_assistant = MagicMock() mock_assistant.id = "abc123" diff --git a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py index 16733d25cb3..8b97e226ec7 100644 --- a/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py +++ b/libs/langchain/tests/unit_tests/callbacks/fake_callback_handler.py @@ -7,6 +7,7 @@ from uuid import UUID from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.messages import BaseMessage from pydantic import BaseModel +from typing_extensions import override class BaseFakeCallbackHandler(BaseModel): @@ -135,6 +136,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): """Whether to ignore retriever callbacks.""" return self.ignore_retriever_ + @override def on_llm_start( self, *args: Any, @@ -142,6 +144,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_start_common() + @override def on_llm_new_token( self, *args: Any, @@ -149,6 +152,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_new_token_common() + @override def on_llm_end( self, *args: Any, @@ -156,6 +160,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_end_common() + @override def on_llm_error( self, *args: Any, @@ -163,6 +168,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_llm_error_common() + @override def on_retry( self, *args: Any, @@ 
-170,6 +176,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retry_common() + @override def on_chain_start( self, *args: Any, @@ -177,6 +184,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_start_common() + @override def on_chain_end( self, *args: Any, @@ -184,6 +192,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_end_common() + @override def on_chain_error( self, *args: Any, @@ -191,6 +200,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_chain_error_common() + @override def on_tool_start( self, *args: Any, @@ -198,6 +208,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_start_common() + @override def on_tool_end( self, *args: Any, @@ -205,6 +216,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_end_common() + @override def on_tool_error( self, *args: Any, @@ -212,6 +224,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_tool_error_common() + @override def on_agent_action( self, *args: Any, @@ -219,6 +232,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_agent_action_common() + @override def on_agent_finish( self, *args: Any, @@ -226,6 +240,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_agent_finish_common() + @override def on_text( self, *args: Any, @@ -233,6 +248,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_text_common() + @override def on_retriever_start( self, *args: Any, @@ -240,6 +256,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_start_common() + @override def on_retriever_end( self, *args: Any, @@ -247,6 +264,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): ) -> Any: self.on_retriever_end_common() + @override def on_retriever_error( self, *args: Any, @@ -259,6 +277,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): + @override def on_chat_model_start( self, serialized: dict[str, Any], @@ -290,6 +309,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi """Whether to ignore agent callbacks.""" return self.ignore_agent_ + @override async def on_retry( self, *args: Any, @@ -297,6 +317,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> Any: self.on_retry_common() + @override async def on_llm_start( self, *args: Any, @@ -304,6 +325,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_start_common() + @override async def on_llm_new_token( self, *args: Any, @@ -311,6 +333,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_new_token_common() + @override async def on_llm_end( self, *args: Any, @@ -318,6 +341,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> None: self.on_llm_end_common() + @override async def on_llm_error( self, *args: Any, @@ -325,6 +349,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi ) -> 
None:
         self.on_llm_error_common()
 
+    @override
     async def on_chain_start(
         self,
         *args: Any,
@@ -332,6 +357,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_chain_start_common()
 
+    @override
     async def on_chain_end(
         self,
         *args: Any,
@@ -339,6 +365,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_chain_end_common()
 
+    @override
     async def on_chain_error(
         self,
         *args: Any,
@@ -346,6 +373,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_chain_error_common()
 
+    @override
     async def on_tool_start(
         self,
         *args: Any,
@@ -353,6 +381,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_tool_start_common()
 
+    @override
     async def on_tool_end(
         self,
         *args: Any,
@@ -360,6 +389,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_tool_end_common()
 
+    @override
     async def on_tool_error(
         self,
         *args: Any,
@@ -367,6 +397,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_tool_error_common()
 
+    @override
     async def on_agent_action(
         self,
         *args: Any,
@@ -374,6 +405,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_agent_action_common()
 
+    @override
     async def on_agent_finish(
         self,
         *args: Any,
@@ -381,6 +413,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
     ) -> None:
         self.on_agent_finish_common()
 
+    @override
     async def on_text(
         self,
         *args: Any,
diff --git a/libs/langchain/tests/unit_tests/callbacks/test_file.py b/libs/langchain/tests/unit_tests/callbacks/test_file.py
index fff64bd93a4..7e91bd15266 100644
--- a/libs/langchain/tests/unit_tests/callbacks/test_file.py
+++ b/libs/langchain/tests/unit_tests/callbacks/test_file.py
@@ -3,6 +3,7 @@ import re
 from typing import Optional
 
 from langchain_core.callbacks import CallbackManagerForChainRun
+from typing_extensions import override
 
 from langchain.callbacks import FileCallbackHandler
 from langchain.chains.base import Chain
@@ -25,6 +26,7 @@ class FakeChain(Chain):
         """Output key of bar."""
         return self.the_output_keys
 
+    @override
     def _call(
         self,
         inputs: dict[str, str],
diff --git a/libs/langchain/tests/unit_tests/callbacks/test_stdout.py b/libs/langchain/tests/unit_tests/callbacks/test_stdout.py
index 4c35ce3ff2f..a19fc6cf4e6 100644
--- a/libs/langchain/tests/unit_tests/callbacks/test_stdout.py
+++ b/libs/langchain/tests/unit_tests/callbacks/test_stdout.py
@@ -2,6 +2,7 @@ from typing import Any, Optional
 
 import pytest
 from langchain_core.callbacks import CallbackManagerForChainRun
+from typing_extensions import override
 
 from langchain.callbacks import StdOutCallbackHandler
 from langchain.chains.base import Chain
@@ -24,6 +25,7 @@ class FakeChain(Chain):
         """Output key of bar."""
         return self.the_output_keys
 
+    @override
     def _call(
         self,
         inputs: dict[str, str],
diff --git a/libs/langchain/tests/unit_tests/chains/test_base.py b/libs/langchain/tests/unit_tests/chains/test_base.py
index 0e47902bbee..a40f0c9f6f2 100644
--- a/libs/langchain/tests/unit_tests/chains/test_base.py
+++ b/libs/langchain/tests/unit_tests/chains/test_base.py
@@ -7,6 +7,7 @@ import pytest
 from langchain_core.callbacks.manager import CallbackManagerForChainRun
 from langchain_core.memory import BaseMemory
 from langchain_core.tracers.context import collect_runs
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 from langchain.schema import RUN_KEY
@@ -21,6 +22,7 @@ class FakeMemory(BaseMemory):
         """Return baz variable."""
         return ["baz"]
 
+    @override
     def load_memory_variables(
         self,
         inputs: Optional[dict[str, Any]] = None,
@@ -52,6 +54,7 @@ class FakeChain(Chain):
         """Output key of bar."""
         return self.the_output_keys
 
+    @override
     def _call(
         self,
         inputs: dict[str, str],
diff --git a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
index ff08295179e..486e48bf1e7 100644
--- a/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
+++ b/libs/langchain/tests/unit_tests/chains/test_combine_documents.py
@@ -19,7 +19,7 @@ def _fake_docs_len_func(docs: list[Document]) -> int:
     return len(_fake_combine_docs_func(docs))
 
 
-def _fake_combine_docs_func(docs: list[Document], **kwargs: Any) -> str:
+def _fake_combine_docs_func(docs: list[Document], **_: Any) -> str:
     return "".join([d.page_content for d in docs])
 
 
diff --git a/libs/langchain/tests/unit_tests/chains/test_conversation.py b/libs/langchain/tests/unit_tests/chains/test_conversation.py
index ec494d85088..9a7749f36f8 100644
--- a/libs/langchain/tests/unit_tests/chains/test_conversation.py
+++ b/libs/langchain/tests/unit_tests/chains/test_conversation.py
@@ -8,6 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LLM
 from langchain_core.memory import BaseMemory
 from langchain_core.prompts.prompt import PromptTemplate
+from typing_extensions import override
 
 from langchain.chains.conversation.base import ConversationChain
 from langchain.memory.buffer import ConversationBufferMemory
@@ -26,6 +27,7 @@ class DummyLLM(LLM):
     def _llm_type(self) -> str:
         return "dummy"
 
+    @override
     def _call(
         self,
         prompt: str,
diff --git a/libs/langchain/tests/unit_tests/chains/test_hyde.py b/libs/langchain/tests/unit_tests/chains/test_hyde.py
index 52e210a89a3..ac9894c0045 100644
--- a/libs/langchain/tests/unit_tests/chains/test_hyde.py
+++ b/libs/langchain/tests/unit_tests/chains/test_hyde.py
@@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
 from langchain_core.embeddings import Embeddings
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, LLMResult
+from typing_extensions import override
 
 from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
 from langchain.chains.hyde.prompts import PROMPT_MAP
@@ -18,10 +19,12 @@ from langchain.chains.hyde.prompts import PROMPT_MAP
 class FakeEmbeddings(Embeddings):
     """Fake embedding class for tests."""
 
+    @override
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         """Return random floats."""
         return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
 
+    @override
     def embed_query(self, text: str) -> list[float]:
         """Return random floats."""
         return list(np.random.uniform(0, 1, 10))
@@ -32,6 +35,7 @@ class FakeLLM(BaseLLM):
 
     n: int = 1
 
+    @override
     def _generate(
         self,
         prompts: list[str],
@@ -41,6 +45,7 @@
     ) -> LLMResult:
         return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
 
+    @override
     async def _agenerate(
         self,
         prompts: list[str],
diff --git a/libs/langchain/tests/unit_tests/chains/test_sequential.py b/libs/langchain/tests/unit_tests/chains/test_sequential.py
index f3313e9139f..aba6df37616 100644
--- a/libs/langchain/tests/unit_tests/chains/test_sequential.py
+++ b/libs/langchain/tests/unit_tests/chains/test_sequential.py
@@ -8,6 +8,7 @@ from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
     CallbackManagerForChainRun,
 )
+from typing_extensions import override
 
 from langchain.chains.base import Chain
 from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
@@ -32,6 +33,7 @@ class FakeChain(Chain):
         """Input keys this chain returns."""
         return self.output_variables
 
+    @override
     def _call(
         self,
         inputs: dict[str, str],
@@ -43,6 +45,7 @@ class FakeChain(Chain):
             outputs[var] = f"{' '.join(variables)}foo"
         return outputs
 
+    @override
     async def _acall(
         self,
         inputs: dict[str, str],
diff --git a/libs/langchain/tests/unit_tests/document_loaders/test_base.py b/libs/langchain/tests/unit_tests/document_loaders/test_base.py
index ec560b853d8..f09ada756f4 100644
--- a/libs/langchain/tests/unit_tests/document_loaders/test_base.py
+++ b/libs/langchain/tests/unit_tests/document_loaders/test_base.py
@@ -4,6 +4,7 @@ from collections.abc import Iterator
 
 from langchain_core.document_loaders import BaseBlobParser, Blob
 from langchain_core.documents import Document
+from typing_extensions import override
 
 
 def test_base_blob_parser() -> None:
@@ -12,6 +13,7 @@ def test_base_blob_parser() -> None:
     class MyParser(BaseBlobParser):
         """A simple parser that returns a single document."""
 
+        @override
         def lazy_parse(self, blob: Blob) -> Iterator[Document]:
             """Lazy parsing interface."""
             yield Document(
diff --git a/libs/langchain/tests/unit_tests/embeddings/test_caching.py b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
index b4ab2c93f3d..4dd139ed1a9 100644
--- a/libs/langchain/tests/unit_tests/embeddings/test_caching.py
+++ b/libs/langchain/tests/unit_tests/embeddings/test_caching.py
@@ -7,12 +7,14 @@ import warnings
 
 import pytest
 from langchain_core.embeddings import Embeddings
+from typing_extensions import override
 
 from langchain.embeddings import CacheBackedEmbeddings
 from langchain.storage.in_memory import InMemoryStore
 
 
 class MockEmbeddings(Embeddings):
+    @override
     def embed_documents(self, texts: list[str]) -> list[list[float]]:
         # Simulate embedding documents
         embeddings: list[list[float]] = []
@@ -23,6 +25,7 @@ class MockEmbeddings(Embeddings):
             embeddings.append([len(text), len(text) + 1])
         return embeddings
 
+    @override
     def embed_query(self, text: str) -> list[float]:
         # Simulate embedding a query
         return [5.0, 6.0]
diff --git a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
index 50deade41d3..02f091b3256 100644
--- a/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
+++ b/libs/langchain/tests/unit_tests/evaluation/agents/test_eval_chain.py
@@ -9,6 +9,7 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.messages import BaseMessage
 from langchain_core.tools import tool
 from pydantic import Field
+from typing_extensions import override
 
 from langchain.evaluation.agents.trajectory_eval_chain import (
     TrajectoryEval,
@@ -43,6 +44,7 @@ class _FakeTrajectoryChatModel(FakeChatModel):
     sequential_responses: Optional[bool] = False
     response_index: int = 0
 
+    @override
     def _call(
         self,
         messages: list[BaseMessage],
diff --git a/libs/langchain/tests/unit_tests/indexes/test_indexing.py b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
index e5954a5e374..723fff342cd 100644
--- a/libs/langchain/tests/unit_tests/indexes/test_indexing.py
+++ b/libs/langchain/tests/unit_tests/indexes/test_indexing.py
@@ -13,6 +13,7 @@ from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.indexing.api import _abatch, _get_document_with_hash
 from langchain_core.vectorstores import VST, VectorStore
+from typing_extensions import override
 
 from langchain.indexes import aindex, index
 from langchain.indexes._sql_record_manager import SQLRecordManager
@@ -45,18 +46,21 @@ class InMemoryVectorStore(VectorStore):
         self.store: dict[str, Document] = {}
         self.permit_upserts = permit_upserts
 
+    @override
     def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
         """Delete the given documents from the store using their IDs."""
         if ids:
             for _id in ids:
                 self.store.pop(_id, None)
 
+    @override
     async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
         """Delete the given documents from the store using their IDs."""
         if ids:
             for _id in ids:
                 self.store.pop(_id, None)
 
+    @override
     def add_documents(
         self,
         documents: Sequence[Document],
@@ -81,6 +85,7 @@ class InMemoryVectorStore(VectorStore):
 
         return list(ids)
 
+    @override
     async def aadd_documents(
         self,
         documents: Sequence[Document],
diff --git a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
index aa59640829d..f11141a01d1 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_chat_model.py
@@ -16,11 +16,13 @@ from langchain_core.messages import (
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 from langchain_core.runnables import run_in_executor
+from typing_extensions import override
 
 
 class FakeChatModel(SimpleChatModel):
     """Fake Chat Model wrapper for testing purposes."""
 
+    @override
     def _call(
         self,
         messages: list[BaseMessage],
@@ -30,6 +32,7 @@ class FakeChatModel(SimpleChatModel):
     ) -> str:
         return "fake response"
 
+    @override
     async def _agenerate(
         self,
         messages: list[BaseMessage],
@@ -74,6 +77,7 @@ class GenericFakeChatModel(BaseChatModel):
     into message chunks.
     """
 
+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
diff --git a/libs/langchain/tests/unit_tests/llms/fake_llm.py b/libs/langchain/tests/unit_tests/llms/fake_llm.py
index 61efe09cc2e..f4033744255 100644
--- a/libs/langchain/tests/unit_tests/llms/fake_llm.py
+++ b/libs/langchain/tests/unit_tests/llms/fake_llm.py
@@ -6,6 +6,7 @@ from typing import Any, Optional, cast
 from langchain_core.callbacks.manager import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import LLM
 from pydantic import model_validator
+from typing_extensions import override
 
 
 class FakeLLM(LLM):
@@ -32,6 +33,7 @@ class FakeLLM(LLM):
         """Return type of llm."""
         return "fake"
 
+    @override
     def _call(
         self,
         prompt: str,
diff --git a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
index 9b27f3e1196..e5e8de87f0f 100644
--- a/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
+++ b/libs/langchain/tests/unit_tests/llms/test_fake_chat_model.py
@@ -7,6 +7,7 @@ from uuid import UUID
 from langchain_core.callbacks.base import AsyncCallbackHandler
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
+from typing_extensions import override
 
 from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
 from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
@@ -166,6 +167,7 @@ async def test_callback_handlers() -> None:
             # Required to implement since this is an abstract method
             pass
 
+        @override
         async def on_llm_new_token(
             self,
             token: str,
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
index 9b914bce786..ebb56956f3f 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_fix.py
@@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompts.prompt import PromptTemplate
 from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
+from typing_extensions import override
 
 from langchain.output_parsers.boolean import BooleanOutputParser
 from langchain.output_parsers.datetime import DatetimeOutputParser
@@ -21,6 +22,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
     parse_count: int = 0  # Number of times parse has been called
     attemp_count_before_success: int  # Number of times to fail before succeeding
 
+    @override
     def parse(self, *args: Any, **kwargs: Any) -> str:
         self.parse_count += 1
         if self.parse_count <= self.attemp_count_before_success:
@@ -62,7 +64,7 @@ def test_output_fixing_parser_parse(
 
 
 def test_output_fixing_parser_from_llm() -> None:
-    def fake_llm(prompt: str) -> AIMessage:
+    def fake_llm(_: str) -> AIMessage:
         return AIMessage("2024-07-08T00:00:00.000000Z")
 
     llm = RunnableLambda(fake_llm)
diff --git a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
index 4dc8cbac55f..12b436f73b1 100644
--- a/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
+++ b/libs/langchain/tests/unit_tests/output_parsers/test_retry.py
@@ -7,6 +7,7 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.output_parsers import BaseOutputParser
 from langchain_core.prompt_values import PromptValue, StringPromptValue
 from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
+from typing_extensions import override
 
 from langchain.output_parsers.boolean import BooleanOutputParser
 from langchain.output_parsers.datetime import DatetimeOutputParser
@@ -25,6 +26,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
     attemp_count_before_success: int  # Number of times to fail before succeeding
     error_msg: str = "error"
 
+    @override
     def parse(self, *args: Any, **kwargs: Any) -> str:
         self.parse_count += 1
         if self.parse_count <= self.attemp_count_before_success:
diff --git a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py
index 17793c9fc1d..b50c50b2c61 100644
--- a/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py
+++ b/libs/langchain/tests/unit_tests/retrievers/self_query/test_base.py
@@ -14,6 +14,7 @@ from langchain_core.structured_query import (
     StructuredQuery,
     Visitor,
 )
+from typing_extensions import override
 
 from langchain.chains.query_constructor.schema import AttributeInfo
 from langchain.retrievers import SelfQueryRetriever
@@ -61,6 +62,7 @@ class FakeTranslator(Visitor):
 
 
 class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
+    @override
     def similarity_search(
         self,
         query: str,
diff --git a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
index 6452caff9f4..e47e144edee 100644
--- a/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
+++ b/libs/langchain/tests/unit_tests/retrievers/sequential_retriever.py
@@ -1,5 +1,8 @@
+from typing import Any
+
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
+from typing_extensions import override
 
 
 class SequentialRetriever(BaseRetriever):
@@ -8,17 +11,21 @@ class SequentialRetriever(BaseRetriever):
     sequential_responses: list[list[Document]]
     response_index: int = 0
 
-    def _get_relevant_documents(  # type: ignore[override]
+    @override
+    def _get_relevant_documents(
         self,
         query: str,
+        **kwargs: Any,
     ) -> list[Document]:
         if self.response_index >= len(self.sequential_responses):
             return []
         self.response_index += 1
         return self.sequential_responses[self.response_index - 1]
 
-    async def _aget_relevant_documents(  # type: ignore[override]
+    @override
+    async def _aget_relevant_documents(
         self,
         query: str,
+        **kwargs: Any,
     ) -> list[Document]:
         return self._get_relevant_documents(query)
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
index 225fb469aab..02a67e33a73 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_ensemble.py
@@ -3,6 +3,7 @@ from typing import Optional
 from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
 from langchain_core.documents import Document
 from langchain_core.retrievers import BaseRetriever
+from typing_extensions import override
 
 from langchain.retrievers.ensemble import EnsembleRetriever
 
@@ -10,6 +11,7 @@ from langchain.retrievers.ensemble import EnsembleRetriever
 class MockRetriever(BaseRetriever):
     docs: list[Document]
 
+    @override
     def _get_relevant_documents(
         self,
         query: str,
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py
index 3c5d282e0ab..b908f0f7ac0 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_multi_vector.py
@@ -1,6 +1,7 @@
 from typing import Any, Callable
 
 from langchain_core.documents import Document
+from typing_extensions import override
 
 from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
 from langchain.storage import InMemoryStore
@@ -15,6 +16,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
     def _select_relevance_score_fn(self) -> Callable[[float], float]:
         return self._identity_fn
 
+    @override
     def similarity_search(
         self,
         query: str,
@@ -26,6 +28,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
             return []
         return [res]
 
+    @override
     def similarity_search_with_score(
         self,
         query: str,
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py b/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py
index a8626a4d248..5036ef40259 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_parent_document.py
@@ -3,6 +3,7 @@ from typing import Any
 
 from langchain_core.documents import Document
 from langchain_text_splitters.character import CharacterTextSplitter
+from typing_extensions import override
 
 from langchain.retrievers import ParentDocumentRetriever
 from langchain.storage import InMemoryStore
@@ -10,6 +11,7 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
 
 
 class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
+    @override
     def similarity_search(
         self,
         query: str,
@@ -21,6 +23,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
             return []
         return [res]
 
+    @override
     def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
         print(documents)  # noqa: T201
         return super().add_documents(
diff --git a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py
index f400b7a738e..19d9379a332 100644
--- a/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py
+++ b/libs/langchain/tests/unit_tests/retrievers/test_time_weighted_retriever.py
@@ -8,6 +8,7 @@ import pytest
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
 from langchain_core.vectorstores import VectorStore
+from typing_extensions import override
 
 from langchain.retrievers.time_weighted_retriever import (
     TimeWeightedVectorStoreRetriever,
@@ -31,6 +32,7 @@ def _get_example_memories(k: int = 4) -> list[Document]:
 class MockVectorStore(VectorStore):
     """Mock invalid vector store."""
 
+    @override
     def add_texts(
         self,
         texts: Iterable[str],
@@ -39,6 +41,7 @@ class MockVectorStore(VectorStore):
     ) -> list[str]:
         return list(texts)
 
+    @override
     def similarity_search(
         self,
         query: str,
@@ -48,6 +51,7 @@ class MockVectorStore(VectorStore):
         return []
 
     @classmethod
+    @override
     def from_texts(
         cls: type["MockVectorStore"],
         texts: list[str],
@@ -57,6 +61,7 @@ class MockVectorStore(VectorStore):
     ) -> "MockVectorStore":
         return cls()
 
+    @override
     def _similarity_search_with_relevance_scores(
         self,
         query: str,
diff --git a/libs/langchain/tests/unit_tests/runnables/test_hub.py b/libs/langchain/tests/unit_tests/runnables/test_hub.py
index 43915bdeeff..ad2ffb4b245 100644
--- a/libs/langchain/tests/unit_tests/runnables/test_hub.py
+++ b/libs/langchain/tests/unit_tests/runnables/test_hub.py
@@ -38,7 +38,7 @@ repo_dict = {
 }
 
 
-def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate:
+def repo_lookup(owner_repo_commit: str, **_: Any) -> ChatPromptTemplate:
     return repo_dict[owner_repo_commit]
 
 
diff --git a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py
index 646594f0f26..5579791e063 100644
--- a/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py
+++ b/libs/langchain/tests/unit_tests/runnables/test_openai_functions.py
@@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, BaseMessage
 from langchain_core.outputs import ChatGeneration, ChatResult
 from pytest_mock import MockerFixture
 from syrupy.assertion import SnapshotAssertion
+from typing_extensions import override
 
 from langchain.runnables.openai_functions import OpenAIFunctionsRouter
 
@@ -15,6 +16,7 @@ class FakeChatOpenAI(BaseChatModel):
     def _llm_type(self) -> str:
         return "fake-openai-chat-model"
 
+    @override
     def _generate(
         self,
         messages: list[BaseMessage],
diff --git a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
index 05f6d691092..c35ae88893e 100644
--- a/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
+++ b/libs/langchain/tests/unit_tests/smith/evaluation/test_runner_utils.py
@@ -3,16 +3,14 @@
 import uuid
 from collections.abc import Iterator
 from datetime import datetime, timezone
-from typing import Any, Optional, Union
+from typing import Any
 from unittest import mock
 
 import pytest
 from freezegun import freeze_time
-from langchain_core.language_models import BaseLanguageModel
 from langsmith.client import Client
 from langsmith.schemas import Dataset, Example
 
-from langchain.chains.base import Chain
 from langchain.chains.transform import TransformChain
 from langchain.smith.evaluation.runner_utils import (
     InputFormatError,
@@ -243,7 +241,7 @@ def test_run_chat_model_all_formats(inputs: dict[str, Any]) -> None:
 
 
 @freeze_time("2023-01-01")
-async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
+async def test_arun_on_dataset() -> None:
     dataset = Dataset(
         id=uuid.uuid4(),
         name="test",
@@ -298,22 +296,20 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
         ),
     ]
 
-    def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
+    def mock_read_dataset(*_: Any, **__: Any) -> Dataset:
         return dataset
 
-    def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
+    def mock_list_examples(*_: Any, **__: Any) -> Iterator[Example]:
         return iter(examples)
 
     async def mock_arun_chain(
         example: Example,
-        llm_or_chain: Union[BaseLanguageModel, Chain],
-        tags: Optional[list[str]] = None,
-        callbacks: Optional[Any] = None,
-        **kwargs: Any,
+        *_: Any,
+        **__: Any,
     ) -> dict[str, Any]:
         return {"result": f"Result for example {example.id}"}
 
-    def mock_create_project(*args: Any, **kwargs: Any) -> Any:
+    def mock_create_project(*_: Any, **__: Any) -> Any:
         proj = mock.MagicMock()
         proj.id = "123"
         return proj
diff --git a/libs/langchain/tests/unit_tests/tools/test_render.py b/libs/langchain/tests/unit_tests/tools/test_render.py
index 66df360c19d..97a36599332 100644
--- a/libs/langchain/tests/unit_tests/tools/test_render.py
+++ b/libs/langchain/tests/unit_tests/tools/test_render.py
@@ -8,13 +8,13 @@ from langchain.tools.render import (
 
 
 @tool
-def search(query: str) -> str:
+def search(query: str) -> str:  # noqa: ARG001
    """Lookup things online."""
    return "foo"
 
 
 @tool
-def calculator(expression: str) -> str:
+def calculator(expression: str) -> str:  # noqa: ARG001
    """Do math."""
    return "bar"