chore(langchain): add ruff rules ARG (#32110)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg

Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
Christophe Bornet 2025-07-27 00:32:34 +02:00 committed by GitHub
parent a2ad5aca41
commit efdfa00d10
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
79 changed files with 241 additions and 62 deletions

View File

@ -116,8 +116,8 @@ class BaseSingleActionAgent(BaseModel):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
**kwargs: Any, **_: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -125,7 +125,6 @@ class BaseSingleActionAgent(BaseModel):
early_stopping_method: Method to use for early stopping. early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations. along with observations.
**kwargs: User inputs.
Returns: Returns:
AgentFinish: Agent finish object. AgentFinish: Agent finish object.
@ -168,6 +167,7 @@ class BaseSingleActionAgent(BaseModel):
"""Return Identifier of an agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
@override
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent. """Return dictionary representation of agent.
@ -289,8 +289,8 @@ class BaseMultiActionAgent(BaseModel):
def return_stopped_response( def return_stopped_response(
self, self,
early_stopping_method: str, early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]], intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
**kwargs: Any, **_: Any,
) -> AgentFinish: ) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations. """Return response when agent has been stopped due to max iterations.
@ -298,7 +298,6 @@ class BaseMultiActionAgent(BaseModel):
early_stopping_method: Method to use for early stopping. early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date, intermediate_steps: Steps the LLM has taken to date,
along with observations. along with observations.
**kwargs: User inputs.
Returns: Returns:
AgentFinish: Agent finish object. AgentFinish: Agent finish object.
@ -317,6 +316,7 @@ class BaseMultiActionAgent(BaseModel):
"""Return Identifier of an agent type.""" """Return Identifier of an agent type."""
raise NotImplementedError raise NotImplementedError
@override
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().model_dump() _dict = super().model_dump()
@ -651,6 +651,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
""" """
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"}) return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
@override
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
@ -735,6 +736,7 @@ class Agent(BaseSingleActionAgent):
allowed_tools: Optional[list[str]] = None allowed_tools: Optional[list[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed.""" """Allowed tools for the agent. If None, all tools are allowed."""
@override
def dict(self, **kwargs: Any) -> builtins.dict: def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.""" """Return dictionary representation of agent."""
_dict = super().dict() _dict = super().dict()
@ -750,18 +752,6 @@ class Agent(BaseSingleActionAgent):
"""Return values of the agent.""" """Return values of the agent."""
return ["output"] return ["output"]
def _fix_text(self, text: str) -> str:
"""Fix the text.
Args:
text: Text to fix.
Returns:
str: Fixed text.
"""
msg = "fix_text not implemented for this agent."
raise ValueError(msg)
@property @property
def _stop(self) -> list[str]: def _stop(self) -> list[str]:
return [ return [
@ -1021,6 +1011,7 @@ class ExceptionTool(BaseTool):
description: str = "Exception tool" description: str = "Exception tool"
"""Description of the tool.""" """Description of the tool."""
@override
def _run( def _run(
self, self,
query: str, query: str,
@ -1028,6 +1019,7 @@ class ExceptionTool(BaseTool):
) -> str: ) -> str:
return query return query
@override
async def _arun( async def _arun(
self, self,
query: str, query: str,
@ -1188,6 +1180,7 @@ class AgentExecutor(Chain):
return cast("RunnableAgentType", self.agent) return cast("RunnableAgentType", self.agent)
return self.agent return self.agent
@override
def save(self, file_path: Union[Path, str]) -> None: def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors. """Raise error - saving not supported for Agent Executors.
@ -1218,7 +1211,7 @@ class AgentExecutor(Chain):
callbacks: Callbacks = None, callbacks: Callbacks = None,
*, *,
include_run_info: bool = False, include_run_info: bool = False,
async_: bool = False, # arg kept for backwards compat, but ignored async_: bool = False, # noqa: ARG002 arg kept for backwards compat, but ignored
) -> AgentExecutorIterator: ) -> AgentExecutorIterator:
"""Enables iteration over steps taken to reach final output. """Enables iteration over steps taken to reach final output.

View File

@ -13,6 +13,7 @@ from langchain_core.prompts.chat import (
) )
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent import Agent, AgentOutputParser
@ -65,6 +66,7 @@ class ChatAgent(Agent):
return agent_scratchpad return agent_scratchpad
@classmethod @classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ChatOutputParser() return ChatOutputParser()

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent import Agent, AgentOutputParser
@ -35,6 +36,7 @@ class ConversationalAgent(Agent):
"""Output parser for the agent.""" """Output parser for the agent."""
@classmethod @classmethod
@override
def _get_default_output_parser( def _get_default_output_parser(
cls, cls,
ai_prefix: str = "AI", ai_prefix: str = "AI",

View File

@ -20,6 +20,7 @@ from langchain_core.prompts.chat import (
) )
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain.agents.agent import Agent, AgentOutputParser from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
@ -42,6 +43,7 @@ class ConversationalChatAgent(Agent):
"""Template for the tool response.""" """Template for the tool response."""
@classmethod @classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ConvoOutputParser() return ConvoOutputParser()

View File

@ -12,6 +12,7 @@ from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool, Tool from langchain_core.tools import BaseTool, Tool
from langchain_core.tools.render import render_text_description from langchain_core.tools.render import render_text_description
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
@ -51,6 +52,7 @@ class ZeroShotAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser) output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)
@classmethod @classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return MRKLOutputParser() return MRKLOutputParser()

View File

@ -4,6 +4,7 @@ from typing import Any
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.agents.format_scratchpad import ( from langchain.agents.format_scratchpad import (
format_to_openai_function_messages, format_to_openai_function_messages,
@ -55,6 +56,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer. """Return history buffer.

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
from langchain_core.tools import BaseTool, Tool from langchain_core.tools import BaseTool, Tool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser) output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
@classmethod @classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ReActOutputParser() return ReActOutputParser()
@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
return AgentType.REACT_DOCSTORE return AgentType.REACT_DOCSTORE
@classmethod @classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Return default prompt.""" """Return default prompt."""
return WIKI_PROMPT return WIKI_PROMPT
@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
"""Agent for the ReAct TextWorld chain.""" """Agent for the ReAct TextWorld chain."""
@classmethod @classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Return default prompt.""" """Return default prompt."""
return TEXTWORLD_PROMPT return TEXTWORLD_PROMPT

View File

@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool, Tool from langchain_core.tools import BaseTool, Tool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType from langchain.agents.agent_types import AgentType
@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser) output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
@classmethod @classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser: def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return SelfAskOutputParser() return SelfAskOutputParser()
@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
return AgentType.SELF_ASK_WITH_SEARCH return AgentType.SELF_ASK_WITH_SEARCH
@classmethod @classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate: def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Prompt does not depend on tools.""" """Prompt does not depend on tools."""
return PROMPT return PROMPT

View File

@ -71,6 +71,7 @@ class StructuredChatAgent(Agent):
pass pass
@classmethod @classmethod
@override
def _get_default_output_parser( def _get_default_output_parser(
cls, cls,
llm: Optional[BaseLanguageModel] = None, llm: Optional[BaseLanguageModel] = None,

View File

@ -7,6 +7,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.tools import BaseTool, tool from langchain_core.tools import BaseTool, tool
from typing_extensions import override
class InvalidTool(BaseTool): class InvalidTool(BaseTool):
@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
description: str = "Called when tool name is invalid. Suggests valid tool names." description: str = "Called when tool name is invalid. Suggests valid tool names."
"""Description of the tool.""" """Description of the tool."""
@override
def _run( def _run(
self, self,
requested_tool_name: str, requested_tool_name: str,
@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
f"try one of [{available_tool_names_str}]." f"try one of [{available_tool_names_str}]."
) )
@override
async def _arun( async def _arun(
self, self,
requested_tool_name: str, requested_tool_name: str,

View File

@ -4,6 +4,7 @@ import sys
from typing import Any, Optional from typing import Any, Optional
from langchain_core.callbacks import StreamingStdOutCallbackHandler from langchain_core.callbacks import StreamingStdOutCallbackHandler
from typing_extensions import override
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
@ -63,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.stream_prefix = stream_prefix self.stream_prefix = stream_prefix
self.answer_reached = False self.answer_reached = False
@override
def on_llm_start( def on_llm_start(
self, self,
serialized: dict[str, Any], serialized: dict[str, Any],
@ -72,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
"""Run when LLM starts running.""" """Run when LLM starts running."""
self.answer_reached = False self.answer_reached = False
@override
def on_llm_new_token(self, token: str, **kwargs: Any) -> None: def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled.""" """Run on new LLM token. Only available when streaming is enabled."""

View File

@ -388,7 +388,7 @@ except ImportError:
class APIChain: # type: ignore[no-redef] class APIChain: # type: ignore[no-redef]
"""Raise an ImportError if APIChain is used without langchain_community.""" """Raise an ImportError if APIChain is used without langchain_community."""
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *_: Any, **__: Any) -> None:
"""Raise an ImportError if APIChain is used without langchain_community.""" """Raise an ImportError if APIChain is used without langchain_community."""
msg = ( msg = (
"To use the APIChain, you must install the langchain_community package." "To use the APIChain, you must install the langchain_community package."

View File

@ -83,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
""" """
return [self.output_key] return [self.output_key]
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
"""Return the prompt length given the documents passed in. """Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list This can be used by a caller to determine whether passing in a list

View File

@ -402,6 +402,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
return docs[:num_docs] return docs[:num_docs]
@override
def _get_docs( def _get_docs(
self, self,
question: str, question: str,
@ -416,6 +417,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
) )
return self._reduce_tokens_below_limit(docs) return self._reduce_tokens_below_limit(docs)
@override
async def _aget_docs( async def _aget_docs(
self, self,
question: str, question: str,
@ -512,6 +514,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
) )
return values return values
@override
def _get_docs( def _get_docs(
self, self,
question: str, question: str,

View File

@ -1,4 +1,4 @@
def __getattr__(name: str = "") -> None: def __getattr__(_: str = "") -> None:
"""Raise an error on import since is deprecated.""" """Raise an error on import since is deprecated."""
msg = ( msg = (
"This module has been moved to langchain-experimental. " "This module has been moved to langchain-experimental. "

View File

@ -1,4 +1,4 @@
def __getattr__(name: str = "") -> None: def __getattr__(_: str = "") -> None:
"""Raise an error on import since is deprecated.""" """Raise an error on import since is deprecated."""
msg = ( msg = (
"This module has been moved to langchain-experimental. " "This module has been moved to langchain-experimental. "

View File

@ -39,7 +39,7 @@ try:
from langchain_community.llms.loading import load_llm, load_llm_from_config from langchain_community.llms.loading import load_llm, load_llm_from_config
except ImportError: except ImportError:
def load_llm(*args: Any, **kwargs: Any) -> None: def load_llm(*_: Any, **__: Any) -> None:
"""Import error for load_llm.""" """Import error for load_llm."""
msg = ( msg = (
"To use this load_llm functionality you must install the " "To use this load_llm functionality you must install the "
@ -48,7 +48,7 @@ except ImportError:
) )
raise ImportError(msg) raise ImportError(msg)
def load_llm_from_config(*args: Any, **kwargs: Any) -> None: def load_llm_from_config(*_: Any, **__: Any) -> None:
"""Import error for load_llm_from_config.""" """Import error for load_llm_from_config."""
msg = ( msg = (
"To use this load_llm_from_config functionality you must install the " "To use this load_llm_from_config functionality you must install the "

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks import (
) )
from langchain_core.utils import check_package_version, get_from_dict_or_env from langchain_core.utils import check_package_version, get_from_dict_or_env
from pydantic import Field, model_validator from pydantic import Field, model_validator
from typing_extensions import override
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -105,6 +106,7 @@ class OpenAIModerationChain(Chain):
return error_str return error_str
return text return text
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],

View File

@ -16,6 +16,7 @@ from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, model_validator from pydantic import ConfigDict, model_validator
from typing_extensions import override
from langchain.chains import ReduceDocumentsChain from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
""" """
return [self.input_docs_key, self.question_key] return [self.input_docs_key, self.question_key]
@override
def _get_docs( def _get_docs(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""Get docs to run questioning over.""" """Get docs to run questioning over."""
return inputs.pop(self.input_docs_key) return inputs.pop(self.input_docs_key)
@override
async def _aget_docs( async def _aget_docs(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator from pydantic import Field, model_validator
from typing_extensions import override
from langchain.chains.combine_documents.stuff import StuffDocumentsChain from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs] return docs[:num_docs]
@override
def _get_docs( def _get_docs(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],

View File

@ -11,7 +11,7 @@ try:
from lark import Lark, Transformer, v_args from lark import Lark, Transformer, v_args
except ImportError: except ImportError:
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc] def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc]
"""Dummy decorator for when lark is not installed.""" """Dummy decorator for when lark is not installed."""
return lambda _: None return lambda _: None

View File

@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from typing_extensions import override
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -330,6 +331,7 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(msg) raise ValueError(msg)
return values return values
@override
def _get_docs( def _get_docs(
self, self,
question: str, question: str,

View File

@ -11,6 +11,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict from pydantic import ConfigDict
from typing_extensions import override
from langchain.chains.router.base import RouterChain from langchain.chains.router.base import RouterChain
@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
""" """
return self.routing_keys return self.routing_keys
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
results = self.vectorstore.similarity_search(_input, k=1) results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]} return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -63,6 +64,7 @@ class TransformChain(Chain):
""" """
return self.output_variables return self.output_variables
@override
def _call( def _call(
self, self,
inputs: dict[str, str], inputs: dict[str, str],
@ -70,6 +72,7 @@ class TransformChain(Chain):
) -> dict[str, str]: ) -> dict[str, str]:
return self.transform_cb(inputs) return self.transform_cb(inputs)
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],

View File

@ -331,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
""" """
return ["prediction", "reference"] return ["prediction", "reference"]
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -355,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
score = self._compute_score(vectors) score = self._compute_score(vectors)
return {"score": score} return {"score": score}
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -382,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
score = self._compute_score(vectors) score = self._compute_score(vectors)
return {"score": score} return {"score": score}
@override
def _evaluate_strings( def _evaluate_strings(
self, self,
*, *,
@ -416,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
) )
return self._prepare_output(result) return self._prepare_output(result)
@override
async def _aevaluate_strings( async def _aevaluate_strings(
self, self,
*, *,
@ -478,6 +482,7 @@ class PairwiseEmbeddingDistanceEvalChain(
"""Return the evaluation name.""" """Return the evaluation name."""
return f"pairwise_embedding_{self.distance_metric.value}_distance" return f"pairwise_embedding_{self.distance_metric.value}_distance"
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -505,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
score = self._compute_score(vectors) score = self._compute_score(vectors)
return {"score": score} return {"score": score}
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -532,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
score = self._compute_score(vectors) score = self._compute_score(vectors)
return {"score": score} return {"score": score}
@override
def _evaluate_string_pairs( def _evaluate_string_pairs(
self, self,
*, *,
@ -567,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
) )
return self._prepare_output(result) return self._prepare_output(result)
@override
async def _aevaluate_string_pairs( async def _aevaluate_string_pairs(
self, self,
*, *,

View File

@ -1,6 +1,8 @@
import string import string
from typing import Any from typing import Any
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator from langchain.evaluation.schema import StringEvaluator
@ -78,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
""" """
return "exact_match" return "exact_match"
@override
def _evaluate_strings( # type: ignore[override] def _evaluate_strings( # type: ignore[override]
self, self,
*, *,

View File

@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
""" # noqa: E501 """ # noqa: E501
def __init__(self, **kwargs: Any) -> None: def __init__(self, **_: Any) -> None:
"""Initializes the JsonSchemaEvaluator. """Initializes the JsonSchemaEvaluator.
Args:
kwargs: Additional keyword arguments.
Raises: Raises:
ImportError: If the jsonschema package is not installed. ImportError: If the jsonschema package is not installed.
""" """

View File

@ -1,6 +1,8 @@
import re import re
from typing import Any from typing import Any
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator from langchain.evaluation.schema import StringEvaluator
@ -70,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
""" """
return "regex_match" return "regex_match"
@override
def _evaluate_strings( # type: ignore[override] def _evaluate_strings( # type: ignore[override]
self, self,
*, *,

View File

@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
""" """
return f"{self.distance.value}_distance" return f"{self.distance.value}_distance"
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
""" """
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])} return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
""" """
return f"pairwise_{self.distance.value}_distance" return f"pairwise_{self.distance.value}_distance"
@override
def _call( def _call(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
} }
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, Any], inputs: dict[str, Any],
@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]), "score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
} }
@override
def _evaluate_string_pairs( def _evaluate_string_pairs(
self, self,
*, *,
@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
) )
return self._prepare_output(result) return self._prepare_output(result)
@override
async def _aevaluate_string_pairs( async def _aevaluate_string_pairs(
self, self,
*, *,

View File

@ -79,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}
@override
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return key-value pairs given the text input to the chain.""" """Return key-value pairs given the text input to the chain."""
buffer = await self.abuffer() buffer = await self.abuffer()
@ -133,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}

View File

@ -2,6 +2,7 @@ from typing import Any, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}

View File

@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin
from langchain_core.prompts import BasePromptTemplate from langchain_core.prompts import BasePromptTemplate
from langchain_core.utils import pre_init from langchain_core.utils import pre_init
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
@ -133,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
if self.return_messages: if self.return_messages:

View File

@ -3,6 +3,7 @@ from typing import Any, Union
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.utils import pre_init from langchain_core.utils import pre_init
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin from langchain.memory.summary import SummarizerMixin
@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
buffer = self.chat_memory.messages buffer = self.chat_memory.messages
@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
) )
return {self.memory_key: final_buffer} return {self.memory_key: final_buffer}
@override
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Asynchronously return key-value pairs given the text input to the chain.""" """Asynchronously return key-value pairs given the text input to the chain."""
buffer = await self.chat_memory.aget_messages() buffer = await self.chat_memory.aget_messages()

View File

@ -3,6 +3,7 @@ from typing import Any
from langchain_core._api import deprecated from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory from langchain.memory.chat_memory import BaseChatMemory
@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
""" """
return [self.memory_key] return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]: def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.""" """Return history buffer."""
return {self.memory_key: self.buffer} return {self.memory_key: self.buffer}

View File

@ -110,7 +110,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None, prompt: Optional[PromptTemplate] = None,
get_input: Optional[Callable[[str, Document], str]] = None, get_input: Optional[Callable[[str, Document], str]] = None,
llm_chain_kwargs: Optional[dict] = None, llm_chain_kwargs: Optional[dict] = None, # noqa: ARG003
) -> LLMChainExtractor: ) -> LLMChainExtractor:
"""Initialize from LLM.""" """Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt() _prompt = prompt if prompt is not None else _get_default_chain_prompt()

View File

@ -9,6 +9,7 @@ from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import get_from_dict_or_env from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, model_validator from pydantic import ConfigDict, model_validator
from typing_extensions import override
@deprecated( @deprecated(
@ -98,6 +99,7 @@ class CohereRerank(BaseDocumentCompressor):
for res in results for res in results
] ]
@override
def compress_documents( def compress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],

View File

@ -7,6 +7,7 @@ from typing import Optional
from langchain_core.callbacks import Callbacks from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict from pydantic import ConfigDict
from typing_extensions import override
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
@ -25,6 +26,7 @@ class CrossEncoderReranker(BaseDocumentCompressor):
extra="forbid", extra="forbid",
) )
@override
def compress_documents( def compress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],

View File

@ -6,6 +6,7 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init from langchain_core.utils import pre_init
from pydantic import ConfigDict, Field from pydantic import ConfigDict, Field
from typing_extensions import override
def _get_similarity_function() -> Callable: def _get_similarity_function() -> Callable:
@ -50,6 +51,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
raise ValueError(msg) raise ValueError(msg)
return values return values
@override
def compress_documents( def compress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],
@ -93,6 +95,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
stateful_documents[i].state["query_similarity_score"] = similarity[i] stateful_documents[i].state["query_similarity_score"] = similarity[i]
return [stateful_documents[i] for i in included_idxs] return [stateful_documents[i] for i in included_idxs]
@override
async def acompress_documents( async def acompress_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],

View File

@ -67,7 +67,7 @@ class MultiQueryRetriever(BaseRetriever):
retriever: BaseRetriever, retriever: BaseRetriever,
llm: BaseLanguageModel, llm: BaseLanguageModel,
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT, prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
parser_key: Optional[str] = None, parser_key: Optional[str] = None, # noqa: ARG003
include_original: bool = False, # noqa: FBT001,FBT002 include_original: bool = False, # noqa: FBT001,FBT002
) -> "MultiQueryRetriever": ) -> "MultiQueryRetriever":
"""Initialize from llm using default template. """Initialize from llm using default template.

View File

@ -10,6 +10,7 @@ from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator from pydantic import Field, model_validator
from typing_extensions import override
from langchain.storage._lc_store import create_kv_docstore from langchain.storage._lc_store import create_kv_docstore
@ -54,6 +55,7 @@ class MultiVectorRetriever(BaseRetriever):
values["docstore"] = docstore values["docstore"] = docstore
return values return values
@override
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
@ -91,6 +93,7 @@ class MultiVectorRetriever(BaseRetriever):
docs = self.docstore.mget(ids) docs = self.docstore.mget(ids)
return [d for d in docs if d is not None] return [d for d in docs if d is not None]
@override
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,

View File

@ -10,6 +10,7 @@ from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict, Field from pydantic import ConfigDict, Field
from typing_extensions import override
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float: def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
@ -128,6 +129,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
result.append(buffered_doc) result.append(buffered_doc)
return result return result
@override
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,
@ -142,6 +144,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
docs_and_scores.update(self.get_salient_docs(query)) docs_and_scores.update(self.get_salient_docs(query))
return self._get_rescored_docs(docs_and_scores) return self._get_rescored_docs(docs_and_scores)
@override
async def _aget_relevant_documents( async def _aget_relevant_documents(
self, self,
query: str, query: str,

View File

@ -19,7 +19,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
total: int, total: int,
ncols: int = 50, ncols: int = 50,
end_with: str = "\n", end_with: str = "\n",
**kwargs: Any,
): ):
"""Initialize the progress bar. """Initialize the progress bar.

View File

@ -355,6 +355,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
feedback.evaluator_info[RUN_KEY] = output[RUN_KEY] feedback.evaluator_info[RUN_KEY] = output[RUN_KEY]
return feedback return feedback
@override
def evaluate_run( def evaluate_run(
self, self,
run: Run, run: Run,
@ -372,6 +373,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
# TODO: Add run ID once we can declare it via callbacks # TODO: Add run ID once we can declare it via callbacks
) )
@override
async def aevaluate_run( async def aevaluate_run(
self, self,
run: Run, run: Run,

View File

@ -1,7 +1,7 @@
from typing import Any from typing import Any
def __getattr__(name: str = "") -> Any: def __getattr__(_: str = "") -> Any:
msg = ( msg = (
"This tool has been moved to langchain experiment. " "This tool has been moved to langchain experiment. "
"This tool has access to a python REPL. " "This tool has access to a python REPL. "

View File

@ -145,6 +145,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
[tool.ruff.lint] [tool.ruff.lint]
select = [ select = [
"A", # flake8-builtins "A", # flake8-builtins
"ARG", # flake8-unused-arguments
"ASYNC", # flake8-async "ASYNC", # flake8-async
"B", # flake8-bugbear "B", # flake8-bugbear
"C4", # flake8-comprehensions "C4", # flake8-comprehensions

View File

@ -3,6 +3,7 @@
import math import math
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from typing_extensions import override
fake_texts = ["foo", "bar", "baz"] fake_texts = ["foo", "bar", "baz"]
@ -18,6 +19,7 @@ class FakeEmbeddings(Embeddings):
async def aembed_documents(self, texts: list[str]) -> list[list[float]]: async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts) return self.embed_documents(texts)
@override
def embed_query(self, text: str) -> list[float]: def embed_query(self, text: str) -> list[float]:
"""Return constant query embeddings. """Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0]. Embeddings are identical to embed_documents(texts)[0].

View File

@ -25,6 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.utils import add from langchain_core.runnables.utils import add
from langchain_core.tools import Tool, tool from langchain_core.tools import Tool, tool
from langchain_core.tracers import RunLog, RunLogPatch from langchain_core.tracers import RunLog, RunLogPatch
from typing_extensions import override
from langchain.agents import ( from langchain.agents import (
AgentExecutor, AgentExecutor,
@ -48,6 +49,7 @@ class FakeListLLM(LLM):
responses: list[str] responses: list[str]
i: int = -1 i: int = -1
@override
def _call( def _call(
self, self,
prompt: str, prompt: str,
@ -462,7 +464,7 @@ async def test_runnable_agent() -> None:
], ],
) )
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser.""" """A parser."""
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message") return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
@ -569,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
], ],
) )
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser.""" """A parser."""
return cast("Union[AgentFinish, AgentAction]", next(parser_responses)) return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@ -681,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
], ],
) )
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]: def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser.""" """A parser."""
return cast("Union[AgentFinish, AgentAction]", next(parser_responses)) return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@ -1032,7 +1034,7 @@ async def test_openai_agent_tools_agent() -> None:
], ],
) )
GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc] GenericFakeChatModel.bind_tools = lambda self, _: self # type: ignore[assignment,misc]
model = GenericFakeChatModel(messages=infinite_cycle) model = GenericFakeChatModel(messages=infinite_cycle)
@tool @tool

View File

@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.utils import add from langchain_core.runnables.utils import add
from langchain_core.tools import Tool from langchain_core.tools import Tool
from typing_extensions import override
from langchain.agents import AgentExecutor, AgentType, initialize_agent from langchain.agents import AgentExecutor, AgentType, initialize_agent
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@ -19,6 +20,7 @@ class FakeListLLM(LLM):
responses: list[str] responses: list[str]
i: int = -1 i: int = -1
@override
def _call( def _call(
self, self,
prompt: str, prompt: str,

View File

@ -364,7 +364,7 @@ def test_agent_iterator_failing_tool() -> None:
tools = [ tools = [
Tool( Tool(
name="FailingTool", name="FailingTool",
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError func=lambda _: 1 / 0, # This tool will raise a ZeroDivisionError
description="A tool that fails", description="A tool that fails",
), ),
] ]

View File

@ -8,7 +8,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
@tool @tool
def my_tool(query: str) -> str: def my_tool(query: str) -> str: # noqa: ARG001
"""A fake tool.""" """A fake tool."""
return "fake tool" return "fake tool"

View File

@ -141,11 +141,11 @@ def test_valid_action_and_answer_raises_exception() -> None:
def test_from_chains() -> None: def test_from_chains() -> None:
"""Test initializing from chains.""" """Test initializing from chains."""
chain_configs = [ chain_configs = [
Tool(name="foo", func=lambda x: "foo", description="foobar1"), Tool(name="foo", func=lambda _x: "foo", description="foobar1"),
Tool(name="bar", func=lambda x: "bar", description="foobar2"), Tool(name="bar", func=lambda _x: "bar", description="foobar2"),
] ]
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs) agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2" expected_tools_prompt = "foo(_x) - foobar1\nbar(_x) - foobar2"
expected_tool_names = "foo, bar" expected_tool_names = "foo, bar"
expected_template = "\n\n".join( expected_template = "\n\n".join(
[ [

View File

@ -7,7 +7,7 @@ import pytest
from langchain.agents.openai_assistant import OpenAIAssistantRunnable from langchain.agents.openai_assistant import OpenAIAssistantRunnable
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any: def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:
client = AsyncMock() if use_async else MagicMock() client = AsyncMock() if use_async else MagicMock()
mock_assistant = MagicMock() mock_assistant = MagicMock()
mock_assistant.id = "abc123" mock_assistant.id = "abc123"

View File

@ -7,6 +7,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import override
class BaseFakeCallbackHandler(BaseModel): class BaseFakeCallbackHandler(BaseModel):
@ -135,6 +136,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks.""" """Whether to ignore retriever callbacks."""
return self.ignore_retriever_ return self.ignore_retriever_
@override
def on_llm_start( def on_llm_start(
self, self,
*args: Any, *args: Any,
@ -142,6 +144,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_start_common() self.on_llm_start_common()
@override
def on_llm_new_token( def on_llm_new_token(
self, self,
*args: Any, *args: Any,
@ -149,6 +152,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_new_token_common() self.on_llm_new_token_common()
@override
def on_llm_end( def on_llm_end(
self, self,
*args: Any, *args: Any,
@ -156,6 +160,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_end_common() self.on_llm_end_common()
@override
def on_llm_error( def on_llm_error(
self, self,
*args: Any, *args: Any,
@ -163,6 +168,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_llm_error_common() self.on_llm_error_common()
@override
def on_retry( def on_retry(
self, self,
*args: Any, *args: Any,
@ -170,6 +176,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retry_common() self.on_retry_common()
@override
def on_chain_start( def on_chain_start(
self, self,
*args: Any, *args: Any,
@ -177,6 +184,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_start_common() self.on_chain_start_common()
@override
def on_chain_end( def on_chain_end(
self, self,
*args: Any, *args: Any,
@ -184,6 +192,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_end_common() self.on_chain_end_common()
@override
def on_chain_error( def on_chain_error(
self, self,
*args: Any, *args: Any,
@ -191,6 +200,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_chain_error_common() self.on_chain_error_common()
@override
def on_tool_start( def on_tool_start(
self, self,
*args: Any, *args: Any,
@ -198,6 +208,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_start_common() self.on_tool_start_common()
@override
def on_tool_end( def on_tool_end(
self, self,
*args: Any, *args: Any,
@ -205,6 +216,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_end_common() self.on_tool_end_common()
@override
def on_tool_error( def on_tool_error(
self, self,
*args: Any, *args: Any,
@ -212,6 +224,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_tool_error_common() self.on_tool_error_common()
@override
def on_agent_action( def on_agent_action(
self, self,
*args: Any, *args: Any,
@ -219,6 +232,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_agent_action_common() self.on_agent_action_common()
@override
def on_agent_finish( def on_agent_finish(
self, self,
*args: Any, *args: Any,
@ -226,6 +240,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_agent_finish_common() self.on_agent_finish_common()
@override
def on_text( def on_text(
self, self,
*args: Any, *args: Any,
@ -233,6 +248,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_text_common() self.on_text_common()
@override
def on_retriever_start( def on_retriever_start(
self, self,
*args: Any, *args: Any,
@ -240,6 +256,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retriever_start_common() self.on_retriever_start_common()
@override
def on_retriever_end( def on_retriever_end(
self, self,
*args: Any, *args: Any,
@ -247,6 +264,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any: ) -> Any:
self.on_retriever_end_common() self.on_retriever_end_common()
@override
def on_retriever_error( def on_retriever_error(
self, self,
*args: Any, *args: Any,
@ -259,6 +277,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler): class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
@override
def on_chat_model_start( def on_chat_model_start(
self, self,
serialized: dict[str, Any], serialized: dict[str, Any],
@ -290,6 +309,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks.""" """Whether to ignore agent callbacks."""
return self.ignore_agent_ return self.ignore_agent_
@override
async def on_retry( async def on_retry(
self, self,
*args: Any, *args: Any,
@ -297,6 +317,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> Any: ) -> Any:
self.on_retry_common() self.on_retry_common()
@override
async def on_llm_start( async def on_llm_start(
self, self,
*args: Any, *args: Any,
@ -304,6 +325,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_start_common() self.on_llm_start_common()
@override
async def on_llm_new_token( async def on_llm_new_token(
self, self,
*args: Any, *args: Any,
@ -311,6 +333,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_new_token_common() self.on_llm_new_token_common()
@override
async def on_llm_end( async def on_llm_end(
self, self,
*args: Any, *args: Any,
@ -318,6 +341,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_end_common() self.on_llm_end_common()
@override
async def on_llm_error( async def on_llm_error(
self, self,
*args: Any, *args: Any,
@ -325,6 +349,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_llm_error_common() self.on_llm_error_common()
@override
async def on_chain_start( async def on_chain_start(
self, self,
*args: Any, *args: Any,
@ -332,6 +357,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_start_common() self.on_chain_start_common()
@override
async def on_chain_end( async def on_chain_end(
self, self,
*args: Any, *args: Any,
@ -339,6 +365,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_end_common() self.on_chain_end_common()
@override
async def on_chain_error( async def on_chain_error(
self, self,
*args: Any, *args: Any,
@ -346,6 +373,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_chain_error_common() self.on_chain_error_common()
@override
async def on_tool_start( async def on_tool_start(
self, self,
*args: Any, *args: Any,
@ -353,6 +381,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_start_common() self.on_tool_start_common()
@override
async def on_tool_end( async def on_tool_end(
self, self,
*args: Any, *args: Any,
@ -360,6 +389,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_end_common() self.on_tool_end_common()
@override
async def on_tool_error( async def on_tool_error(
self, self,
*args: Any, *args: Any,
@ -367,6 +397,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_tool_error_common() self.on_tool_error_common()
@override
async def on_agent_action( async def on_agent_action(
self, self,
*args: Any, *args: Any,
@ -374,6 +405,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_agent_action_common() self.on_agent_action_common()
@override
async def on_agent_finish( async def on_agent_finish(
self, self,
*args: Any, *args: Any,
@ -381,6 +413,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None: ) -> None:
self.on_agent_finish_common() self.on_agent_finish_common()
@override
async def on_text( async def on_text(
self, self,
*args: Any, *args: Any,

View File

@ -3,6 +3,7 @@ import re
from typing import Optional from typing import Optional
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
from typing_extensions import override
from langchain.callbacks import FileCallbackHandler from langchain.callbacks import FileCallbackHandler
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -25,6 +26,7 @@ class FakeChain(Chain):
"""Output key of bar.""" """Output key of bar."""
return self.the_output_keys return self.the_output_keys
@override
def _call( def _call(
self, self,
inputs: dict[str, str], inputs: dict[str, str],

View File

@ -2,6 +2,7 @@ from typing import Any, Optional
import pytest import pytest
from langchain_core.callbacks import CallbackManagerForChainRun from langchain_core.callbacks import CallbackManagerForChainRun
from typing_extensions import override
from langchain.callbacks import StdOutCallbackHandler from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.base import Chain from langchain.chains.base import Chain
@ -24,6 +25,7 @@ class FakeChain(Chain):
"""Output key of bar.""" """Output key of bar."""
return self.the_output_keys return self.the_output_keys
@override
def _call( def _call(
self, self,
inputs: dict[str, str], inputs: dict[str, str],

View File

@ -7,6 +7,7 @@ import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs from langchain_core.tracers.context import collect_runs
from typing_extensions import override
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.schema import RUN_KEY from langchain.schema import RUN_KEY
@ -21,6 +22,7 @@ class FakeMemory(BaseMemory):
"""Return baz variable.""" """Return baz variable."""
return ["baz"] return ["baz"]
@override
def load_memory_variables( def load_memory_variables(
self, self,
inputs: Optional[dict[str, Any]] = None, inputs: Optional[dict[str, Any]] = None,
@ -52,6 +54,7 @@ class FakeChain(Chain):
"""Output key of bar.""" """Output key of bar."""
return self.the_output_keys return self.the_output_keys
@override
def _call( def _call(
self, self,
inputs: dict[str, str], inputs: dict[str, str],

View File

@ -19,7 +19,7 @@ def _fake_docs_len_func(docs: list[Document]) -> int:
return len(_fake_combine_docs_func(docs)) return len(_fake_combine_docs_func(docs))
def _fake_combine_docs_func(docs: list[Document], **kwargs: Any) -> str: def _fake_combine_docs_func(docs: list[Document], **_: Any) -> str:
return "".join([d.page_content for d in docs]) return "".join([d.page_content for d in docs])

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM from langchain_core.language_models import LLM
from langchain_core.memory import BaseMemory from langchain_core.memory import BaseMemory
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from typing_extensions import override
from langchain.chains.conversation.base import ConversationChain from langchain.chains.conversation.base import ConversationChain
from langchain.memory.buffer import ConversationBufferMemory from langchain.memory.buffer import ConversationBufferMemory
@ -26,6 +27,7 @@ class DummyLLM(LLM):
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "dummy" return "dummy"
@override
def _call( def _call(
self, self,
prompt: str, prompt: str,

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import BaseLLM from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult from langchain_core.outputs import Generation, LLMResult
from typing_extensions import override
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.hyde.prompts import PROMPT_MAP from langchain.chains.hyde.prompts import PROMPT_MAP
@ -18,10 +19,12 @@ from langchain.chains.hyde.prompts import PROMPT_MAP
class FakeEmbeddings(Embeddings): class FakeEmbeddings(Embeddings):
"""Fake embedding class for tests.""" """Fake embedding class for tests."""
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return random floats.""" """Return random floats."""
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)] return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
@override
def embed_query(self, text: str) -> list[float]: def embed_query(self, text: str) -> list[float]:
"""Return random floats.""" """Return random floats."""
return list(np.random.uniform(0, 1, 10)) return list(np.random.uniform(0, 1, 10))
@ -32,6 +35,7 @@ class FakeLLM(BaseLLM):
n: int = 1 n: int = 1
@override
def _generate( def _generate(
self, self,
prompts: list[str], prompts: list[str],
@ -41,6 +45,7 @@ class FakeLLM(BaseLLM):
) -> LLMResult: ) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@override
async def _agenerate( async def _agenerate(
self, self,
prompts: list[str], prompts: list[str],

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun, AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from typing_extensions import override
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
@ -32,6 +33,7 @@ class FakeChain(Chain):
"""Input keys this chain returns.""" """Input keys this chain returns."""
return self.output_variables return self.output_variables
@override
def _call( def _call(
self, self,
inputs: dict[str, str], inputs: dict[str, str],
@ -43,6 +45,7 @@ class FakeChain(Chain):
outputs[var] = f"{' '.join(variables)}foo" outputs[var] = f"{' '.join(variables)}foo"
return outputs return outputs
@override
async def _acall( async def _acall(
self, self,
inputs: dict[str, str], inputs: dict[str, str],

View File

@ -4,6 +4,7 @@ from collections.abc import Iterator
from langchain_core.document_loaders import BaseBlobParser, Blob from langchain_core.document_loaders import BaseBlobParser, Blob
from langchain_core.documents import Document from langchain_core.documents import Document
from typing_extensions import override
def test_base_blob_parser() -> None: def test_base_blob_parser() -> None:
@ -12,6 +13,7 @@ def test_base_blob_parser() -> None:
class MyParser(BaseBlobParser): class MyParser(BaseBlobParser):
"""A simple parser that returns a single document.""" """A simple parser that returns a single document."""
@override
def lazy_parse(self, blob: Blob) -> Iterator[Document]: def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazy parsing interface.""" """Lazy parsing interface."""
yield Document( yield Document(

View File

@ -7,12 +7,14 @@ import warnings
import pytest import pytest
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from typing_extensions import override
from langchain.embeddings import CacheBackedEmbeddings from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage.in_memory import InMemoryStore from langchain.storage.in_memory import InMemoryStore
class MockEmbeddings(Embeddings): class MockEmbeddings(Embeddings):
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
# Simulate embedding documents # Simulate embedding documents
embeddings: list[list[float]] = [] embeddings: list[list[float]] = []
@ -23,6 +25,7 @@ class MockEmbeddings(Embeddings):
embeddings.append([len(text), len(text) + 1]) embeddings.append([len(text), len(text) + 1])
return embeddings return embeddings
@override
def embed_query(self, text: str) -> list[float]: def embed_query(self, text: str) -> list[float]:
# Simulate embedding a query # Simulate embedding a query
return [5.0, 6.0] return [5.0, 6.0]

View File

@ -9,6 +9,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from pydantic import Field from pydantic import Field
from typing_extensions import override
from langchain.evaluation.agents.trajectory_eval_chain import ( from langchain.evaluation.agents.trajectory_eval_chain import (
TrajectoryEval, TrajectoryEval,
@ -43,6 +44,7 @@ class _FakeTrajectoryChatModel(FakeChatModel):
sequential_responses: Optional[bool] = False sequential_responses: Optional[bool] = False
response_index: int = 0 response_index: int = 0
@override
def _call( def _call(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],

View File

@ -13,6 +13,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.indexing.api import _abatch, _get_document_with_hash from langchain_core.indexing.api import _abatch, _get_document_with_hash
from langchain_core.vectorstores import VST, VectorStore from langchain_core.vectorstores import VST, VectorStore
from typing_extensions import override
from langchain.indexes import aindex, index from langchain.indexes import aindex, index
from langchain.indexes._sql_record_manager import SQLRecordManager from langchain.indexes._sql_record_manager import SQLRecordManager
@ -45,18 +46,21 @@ class InMemoryVectorStore(VectorStore):
self.store: dict[str, Document] = {} self.store: dict[str, Document] = {}
self.permit_upserts = permit_upserts self.permit_upserts = permit_upserts
@override
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs.""" """Delete the given documents from the store using their IDs."""
if ids: if ids:
for _id in ids: for _id in ids:
self.store.pop(_id, None) self.store.pop(_id, None)
@override
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None: async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs.""" """Delete the given documents from the store using their IDs."""
if ids: if ids:
for _id in ids: for _id in ids:
self.store.pop(_id, None) self.store.pop(_id, None)
@override
def add_documents( def add_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],
@ -81,6 +85,7 @@ class InMemoryVectorStore(VectorStore):
return list(ids) return list(ids)
@override
async def aadd_documents( async def aadd_documents(
self, self,
documents: Sequence[Document], documents: Sequence[Document],

View File

@ -16,11 +16,13 @@ from langchain_core.messages import (
) )
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor from langchain_core.runnables import run_in_executor
from typing_extensions import override
class FakeChatModel(SimpleChatModel): class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes.""" """Fake Chat Model wrapper for testing purposes."""
@override
def _call( def _call(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -30,6 +32,7 @@ class FakeChatModel(SimpleChatModel):
) -> str: ) -> str:
return "fake response" return "fake response"
@override
async def _agenerate( async def _agenerate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -74,6 +77,7 @@ class GenericFakeChatModel(BaseChatModel):
into message chunks. into message chunks.
""" """
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, cast
from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM from langchain_core.language_models.llms import LLM
from pydantic import model_validator from pydantic import model_validator
from typing_extensions import override
class FakeLLM(LLM): class FakeLLM(LLM):
@ -32,6 +33,7 @@ class FakeLLM(LLM):
"""Return type of llm.""" """Return type of llm."""
return "fake" return "fake"
@override
def _call( def _call(
self, self,
prompt: str, prompt: str,

View File

@ -7,6 +7,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from typing_extensions import override
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
@ -166,6 +167,7 @@ async def test_callback_handlers() -> None:
# Required to implement since this is an abstract method # Required to implement since this is an abstract method
pass pass
@override
async def on_llm_new_token( async def on_llm_new_token(
self, self,
token: str, token: str,

View File

@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from typing_extensions import override
from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser
@ -21,6 +22,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
parse_count: int = 0 # Number of times parse has been called parse_count: int = 0 # Number of times parse has been called
attemp_count_before_success: int # Number of times to fail before succeeding attemp_count_before_success: int # Number of times to fail before succeeding
@override
def parse(self, *args: Any, **kwargs: Any) -> str: def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1 self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success: if self.parse_count <= self.attemp_count_before_success:
@ -62,7 +64,7 @@ def test_output_fixing_parser_parse(
def test_output_fixing_parser_from_llm() -> None: def test_output_fixing_parser_from_llm() -> None:
def fake_llm(prompt: str) -> AIMessage: def fake_llm(_: str) -> AIMessage:
return AIMessage("2024-07-08T00:00:00.000000Z") return AIMessage("2024-07-08T00:00:00.000000Z")
llm = RunnableLambda(fake_llm) llm = RunnableLambda(fake_llm)

View File

@ -7,6 +7,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompt_values import PromptValue, StringPromptValue from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from typing_extensions import override
from langchain.output_parsers.boolean import BooleanOutputParser from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser from langchain.output_parsers.datetime import DatetimeOutputParser
@ -25,6 +26,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
attemp_count_before_success: int # Number of times to fail before succeeding attemp_count_before_success: int # Number of times to fail before succeeding
error_msg: str = "error" error_msg: str = "error"
@override
def parse(self, *args: Any, **kwargs: Any) -> str: def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1 self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success: if self.parse_count <= self.attemp_count_before_success:

View File

@ -14,6 +14,7 @@ from langchain_core.structured_query import (
StructuredQuery, StructuredQuery,
Visitor, Visitor,
) )
from typing_extensions import override
from langchain.chains.query_constructor.schema import AttributeInfo from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers import SelfQueryRetriever from langchain.retrievers import SelfQueryRetriever
@ -61,6 +62,7 @@ class FakeTranslator(Visitor):
class InMemoryVectorstoreWithSearch(InMemoryVectorStore): class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
@override
def similarity_search( def similarity_search(
self, self,
query: str, query: str,

View File

@ -1,5 +1,8 @@
from typing import Any
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
class SequentialRetriever(BaseRetriever): class SequentialRetriever(BaseRetriever):
@ -8,17 +11,21 @@ class SequentialRetriever(BaseRetriever):
sequential_responses: list[list[Document]] sequential_responses: list[list[Document]]
response_index: int = 0 response_index: int = 0
def _get_relevant_documents( # type: ignore[override] @override
def _get_relevant_documents(
self, self,
query: str, query: str,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
if self.response_index >= len(self.sequential_responses): if self.response_index >= len(self.sequential_responses):
return [] return []
self.response_index += 1 self.response_index += 1
return self.sequential_responses[self.response_index - 1] return self.sequential_responses[self.response_index - 1]
async def _aget_relevant_documents( # type: ignore[override] @override
async def _aget_relevant_documents(
self, self,
query: str, query: str,
**kwargs: Any,
) -> list[Document]: ) -> list[Document]:
return self._get_relevant_documents(query) return self._get_relevant_documents(query)

View File

@ -3,6 +3,7 @@ from typing import Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
from langchain.retrievers.ensemble import EnsembleRetriever from langchain.retrievers.ensemble import EnsembleRetriever
@ -10,6 +11,7 @@ from langchain.retrievers.ensemble import EnsembleRetriever
class MockRetriever(BaseRetriever): class MockRetriever(BaseRetriever):
docs: list[Document] docs: list[Document]
@override
def _get_relevant_documents( def _get_relevant_documents(
self, self,
query: str, query: str,

View File

@ -1,6 +1,7 @@
from typing import Any, Callable from typing import Any, Callable
from langchain_core.documents import Document from langchain_core.documents import Document
from typing_extensions import override
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
from langchain.storage import InMemoryStore from langchain.storage import InMemoryStore
@ -15,6 +16,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._identity_fn return self._identity_fn
@override
def similarity_search( def similarity_search(
self, self,
query: str, query: str,
@ -26,6 +28,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
return [] return []
return [res] return [res]
@override
def similarity_search_with_score( def similarity_search_with_score(
self, self,
query: str, query: str,

View File

@ -3,6 +3,7 @@ from typing import Any
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter from langchain_text_splitters.character import CharacterTextSplitter
from typing_extensions import override
from langchain.retrievers import ParentDocumentRetriever from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore from langchain.storage import InMemoryStore
@ -10,6 +11,7 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
class InMemoryVectorstoreWithSearch(InMemoryVectorStore): class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
@override
def similarity_search( def similarity_search(
self, self,
query: str, query: str,
@ -21,6 +23,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
return [] return []
return [res] return [res]
@override
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]: def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
print(documents) # noqa: T201 print(documents) # noqa: T201
return super().add_documents( return super().add_documents(

View File

@ -8,6 +8,7 @@ import pytest
from langchain_core.documents import Document from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore from langchain_core.vectorstores import VectorStore
from typing_extensions import override
from langchain.retrievers.time_weighted_retriever import ( from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever, TimeWeightedVectorStoreRetriever,
@ -31,6 +32,7 @@ def _get_example_memories(k: int = 4) -> list[Document]:
class MockVectorStore(VectorStore): class MockVectorStore(VectorStore):
"""Mock invalid vector store.""" """Mock invalid vector store."""
@override
def add_texts( def add_texts(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -39,6 +41,7 @@ class MockVectorStore(VectorStore):
) -> list[str]: ) -> list[str]:
return list(texts) return list(texts)
@override
def similarity_search( def similarity_search(
self, self,
query: str, query: str,
@ -48,6 +51,7 @@ class MockVectorStore(VectorStore):
return [] return []
@classmethod @classmethod
@override
def from_texts( def from_texts(
cls: type["MockVectorStore"], cls: type["MockVectorStore"],
texts: list[str], texts: list[str],
@ -57,6 +61,7 @@ class MockVectorStore(VectorStore):
) -> "MockVectorStore": ) -> "MockVectorStore":
return cls() return cls()
@override
def _similarity_search_with_relevance_scores( def _similarity_search_with_relevance_scores(
self, self,
query: str, query: str,

View File

@ -38,7 +38,7 @@ repo_dict = {
} }
def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate: def repo_lookup(owner_repo_commit: str, **_: Any) -> ChatPromptTemplate:
return repo_dict[owner_repo_commit] return repo_dict[owner_repo_commit]

View File

@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatResult
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
from langchain.runnables.openai_functions import OpenAIFunctionsRouter from langchain.runnables.openai_functions import OpenAIFunctionsRouter
@ -15,6 +16,7 @@ class FakeChatOpenAI(BaseChatModel):
def _llm_type(self) -> str: def _llm_type(self) -> str:
return "fake-openai-chat-model" return "fake-openai-chat-model"
@override
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],

View File

@ -3,16 +3,14 @@
import uuid import uuid
from collections.abc import Iterator from collections.abc import Iterator
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Optional, Union from typing import Any
from unittest import mock from unittest import mock
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from langchain_core.language_models import BaseLanguageModel
from langsmith.client import Client from langsmith.client import Client
from langsmith.schemas import Dataset, Example from langsmith.schemas import Dataset, Example
from langchain.chains.base import Chain
from langchain.chains.transform import TransformChain from langchain.chains.transform import TransformChain
from langchain.smith.evaluation.runner_utils import ( from langchain.smith.evaluation.runner_utils import (
InputFormatError, InputFormatError,
@ -243,7 +241,7 @@ def test_run_chat_model_all_formats(inputs: dict[str, Any]) -> None:
@freeze_time("2023-01-01") @freeze_time("2023-01-01")
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None: async def test_arun_on_dataset() -> None:
dataset = Dataset( dataset = Dataset(
id=uuid.uuid4(), id=uuid.uuid4(),
name="test", name="test",
@ -298,22 +296,20 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
), ),
] ]
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset: def mock_read_dataset(*_: Any, **__: Any) -> Dataset:
return dataset return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]: def mock_list_examples(*_: Any, **__: Any) -> Iterator[Example]:
return iter(examples) return iter(examples)
async def mock_arun_chain( async def mock_arun_chain(
example: Example, example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain], *_: Any,
tags: Optional[list[str]] = None, **__: Any,
callbacks: Optional[Any] = None,
**kwargs: Any,
) -> dict[str, Any]: ) -> dict[str, Any]:
return {"result": f"Result for example {example.id}"} return {"result": f"Result for example {example.id}"}
def mock_create_project(*args: Any, **kwargs: Any) -> Any: def mock_create_project(*_: Any, **__: Any) -> Any:
proj = mock.MagicMock() proj = mock.MagicMock()
proj.id = "123" proj.id = "123"
return proj return proj

View File

@ -8,13 +8,13 @@ from langchain.tools.render import (
@tool @tool
def search(query: str) -> str: def search(query: str) -> str: # noqa: ARG001
"""Lookup things online.""" """Lookup things online."""
return "foo" return "foo"
@tool @tool
def calculator(expression: str) -> str: def calculator(expression: str) -> str: # noqa: ARG001
"""Do math.""" """Do math."""
return "bar" return "bar"