chore(langchain): add ruff rules ARG (#32110)

See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg

Co-authored-by: Mason Daugherty <mason@langchain.dev>
Authored by Christophe Bornet on 2025-07-27 00:32:34 +02:00, committed by GitHub
parent a2ad5aca41
commit efdfa00d10
79 changed files with 241 additions and 62 deletions
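
For context (an illustrative sketch, not part of this diff): ruff's ARG rules flag arguments that are never used in a function body: ARG001 for plain functions, ARG002 for methods, ARG003 for classmethods, ARG004 for staticmethods, and ARG005 for lambdas. The changes below apply two idioms, shown here with a hypothetical ExampleAgent: rename the unused argument to an underscore name, which the rule treats as a deliberate dummy, or keep the name for API compatibility and suppress the diagnostic with a targeted noqa on the argument's own line.

from typing import Any

class ExampleAgent:
    def return_stopped_response(
        self,
        early_stopping_method: str,  # noqa: ARG002 kept for API compatibility
        **_: Any,  # unused **kwargs renamed to **_, which ARG ignores
    ) -> str:
        return "Agent stopped due to iteration limit."

def handler(_: dict) -> None:  # underscore name, so no ARG001 is reported
    return None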

View File

@ -116,8 +116,8 @@ class BaseSingleActionAgent(BaseModel):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
**_: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -125,7 +125,6 @@ class BaseSingleActionAgent(BaseModel):
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
@ -168,6 +167,7 @@ class BaseSingleActionAgent(BaseModel):
"""Return Identifier of an agent type."""
raise NotImplementedError
@override
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent.
@ -289,8 +289,8 @@ class BaseMultiActionAgent(BaseModel):
def return_stopped_response(
self,
early_stopping_method: str,
intermediate_steps: list[tuple[AgentAction, str]],
**kwargs: Any,
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
**_: Any,
) -> AgentFinish:
"""Return response when agent has been stopped due to max iterations.
@ -298,7 +298,6 @@ class BaseMultiActionAgent(BaseModel):
early_stopping_method: Method to use for early stopping.
intermediate_steps: Steps the LLM has taken to date,
along with observations.
**kwargs: User inputs.
Returns:
AgentFinish: Agent finish object.
@ -317,6 +316,7 @@ class BaseMultiActionAgent(BaseModel):
"""Return Identifier of an agent type."""
raise NotImplementedError
@override
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().model_dump()
@ -651,6 +651,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
"""
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
@override
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
@ -735,6 +736,7 @@ class Agent(BaseSingleActionAgent):
allowed_tools: Optional[list[str]] = None
"""Allowed tools for the agent. If None, all tools are allowed."""
@override
def dict(self, **kwargs: Any) -> builtins.dict:
"""Return dictionary representation of agent."""
_dict = super().dict()
@ -750,18 +752,6 @@ class Agent(BaseSingleActionAgent):
"""Return values of the agent."""
return ["output"]
def _fix_text(self, text: str) -> str:
"""Fix the text.
Args:
text: Text to fix.
Returns:
str: Fixed text.
"""
msg = "fix_text not implemented for this agent."
raise ValueError(msg)
@property
def _stop(self) -> list[str]:
return [
@ -1021,6 +1011,7 @@ class ExceptionTool(BaseTool):
description: str = "Exception tool"
"""Description of the tool."""
@override
def _run(
self,
query: str,
@ -1028,6 +1019,7 @@ class ExceptionTool(BaseTool):
) -> str:
return query
@override
async def _arun(
self,
query: str,
@ -1188,6 +1180,7 @@ class AgentExecutor(Chain):
return cast("RunnableAgentType", self.agent)
return self.agent
@override
def save(self, file_path: Union[Path, str]) -> None:
"""Raise error - saving not supported for Agent Executors.
@ -1218,7 +1211,7 @@ class AgentExecutor(Chain):
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
async_: bool = False, # arg kept for backwards compat, but ignored
async_: bool = False, # noqa: ARG002 arg kept for backwards compat, but ignored
) -> AgentExecutorIterator:
"""Enables iteration over steps taken to reach final output.

View File

@ -13,6 +13,7 @@ from langchain_core.prompts.chat import (
)
from langchain_core.tools import BaseTool
from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentOutputParser
@ -65,6 +66,7 @@ class ChatAgent(Agent):
return agent_scratchpad
@classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ChatOutputParser()

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentOutputParser
@ -35,6 +36,7 @@ class ConversationalAgent(Agent):
"""Output parser for the agent."""
@classmethod
@override
def _get_default_output_parser(
cls,
ai_prefix: str = "AI",

View File

@ -20,6 +20,7 @@ from langchain_core.prompts.chat import (
)
from langchain_core.tools import BaseTool
from pydantic import Field
from typing_extensions import override
from langchain.agents.agent import Agent, AgentOutputParser
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
@ -42,6 +43,7 @@ class ConversationalChatAgent(Agent):
"""Template for the tool response."""
@classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ConvoOutputParser()

View File

@ -12,6 +12,7 @@ from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool, Tool
from langchain_core.tools.render import render_text_description
from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
@ -51,6 +52,7 @@ class ZeroShotAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)
@classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return MRKLOutputParser()

View File

@ -4,6 +4,7 @@ from typing import Any
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.agents.format_scratchpad import (
format_to_openai_function_messages,
@ -55,6 +56,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer.

View File

@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from langchain_core.tools import BaseTool, Tool
from pydantic import Field
from typing_extensions import override
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
@classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return ReActOutputParser()
@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
return AgentType.REACT_DOCSTORE
@classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Return default prompt."""
return WIKI_PROMPT
@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
"""Agent for the ReAct TextWorld chain."""
@classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Return default prompt."""
return TEXTWORLD_PROMPT

View File

@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.tools import BaseTool, Tool
from pydantic import Field
from typing_extensions import override
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
from langchain.agents.agent_types import AgentType
@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
@classmethod
@override
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
return SelfAskOutputParser()
@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
return AgentType.SELF_ASK_WITH_SEARCH
@classmethod
@override
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
"""Prompt does not depend on tools."""
return PROMPT

View File

@ -71,6 +71,7 @@ class StructuredChatAgent(Agent):
pass
@classmethod
@override
def _get_default_output_parser(
cls,
llm: Optional[BaseLanguageModel] = None,

View File

@ -7,6 +7,7 @@ from langchain_core.callbacks import (
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool, tool
from typing_extensions import override
class InvalidTool(BaseTool):
@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
description: str = "Called when tool name is invalid. Suggests valid tool names."
"""Description of the tool."""
@override
def _run(
self,
requested_tool_name: str,
@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
f"try one of [{available_tool_names_str}]."
)
@override
async def _arun(
self,
requested_tool_name: str,

View File

@ -4,6 +4,7 @@ import sys
from typing import Any, Optional
from langchain_core.callbacks import StreamingStdOutCallbackHandler
from typing_extensions import override
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
@ -63,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
self.stream_prefix = stream_prefix
self.answer_reached = False
@override
def on_llm_start(
self,
serialized: dict[str, Any],
@ -72,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
"""Run when LLM starts running."""
self.answer_reached = False
@override
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""

View File

@ -388,7 +388,7 @@ except ImportError:
class APIChain: # type: ignore[no-redef]
"""Raise an ImportError if APIChain is used without langchain_community."""
def __init__(self, *args: Any, **kwargs: Any) -> None:
def __init__(self, *_: Any, **__: Any) -> None:
"""Raise an ImportError if APIChain is used without langchain_community."""
msg = (
"To use the APIChain, you must install the langchain_community package."

View File

@ -83,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
"""
return [self.output_key]
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
"""Return the prompt length given the documents passed in.
This can be used by a caller to determine whether passing in a list

View File

@ -402,6 +402,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
return docs[:num_docs]
@override
def _get_docs(
self,
question: str,
@ -416,6 +417,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
)
return self._reduce_tokens_below_limit(docs)
@override
async def _aget_docs(
self,
question: str,
@ -512,6 +514,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
)
return values
@override
def _get_docs(
self,
question: str,

View File

@ -1,4 +1,4 @@
def __getattr__(name: str = "") -> None:
def __getattr__(_: str = "") -> None:
"""Raise an error on import since is deprecated."""
msg = (
"This module has been moved to langchain-experimental. "

View File

@ -1,4 +1,4 @@
def __getattr__(name: str = "") -> None:
def __getattr__(_: str = "") -> None:
"""Raise an error on import since is deprecated."""
msg = (
"This module has been moved to langchain-experimental. "

View File

@ -39,7 +39,7 @@ try:
from langchain_community.llms.loading import load_llm, load_llm_from_config
except ImportError:
def load_llm(*args: Any, **kwargs: Any) -> None:
def load_llm(*_: Any, **__: Any) -> None:
"""Import error for load_llm."""
msg = (
"To use this load_llm functionality you must install the "
@ -48,7 +48,7 @@ except ImportError:
)
raise ImportError(msg)
def load_llm_from_config(*args: Any, **kwargs: Any) -> None:
def load_llm_from_config(*_: Any, **__: Any) -> None:
"""Import error for load_llm_from_config."""
msg = (
"To use this load_llm_from_config functionality you must install the "

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks import (
)
from langchain_core.utils import check_package_version, get_from_dict_or_env
from pydantic import Field, model_validator
from typing_extensions import override
from langchain.chains.base import Chain
@ -105,6 +106,7 @@ class OpenAIModerationChain(Chain):
return error_str
return text
@override
def _call(
self,
inputs: dict[str, Any],

View File

@ -16,6 +16,7 @@ from langchain_core.documents import Document
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import BasePromptTemplate
from pydantic import ConfigDict, model_validator
from typing_extensions import override
from langchain.chains import ReduceDocumentsChain
from langchain.chains.base import Chain
@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""
return [self.input_docs_key, self.question_key]
@override
def _get_docs(
self,
inputs: dict[str, Any],
@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
"""Get docs to run questioning over."""
return inputs.pop(self.input_docs_key)
@override
async def _aget_docs(
self,
inputs: dict[str, Any],

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks import (
from langchain_core.documents import Document
from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator
from typing_extensions import override
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
return docs[:num_docs]
@override
def _get_docs(
self,
inputs: dict[str, Any],

View File

@ -11,7 +11,7 @@ try:
from lark import Lark, Transformer, v_args
except ImportError:
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc]
"""Dummy decorator for when lark is not installed."""
return lambda _: None

View File

@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict, Field, model_validator
from typing_extensions import override
from langchain.chains.base import Chain
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
@ -330,6 +331,7 @@ class VectorDBQA(BaseRetrievalQA):
raise ValueError(msg)
return values
@override
def _get_docs(
self,
question: str,

View File

@ -11,6 +11,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict
from typing_extensions import override
from langchain.chains.router.base import RouterChain
@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
"""
return self.routing_keys
@override
def _call(
self,
inputs: dict[str, Any],
@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
results = self.vectorstore.similarity_search(_input, k=1)
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
@override
async def _acall(
self,
inputs: dict[str, Any],

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from pydantic import Field
from typing_extensions import override
from langchain.chains.base import Chain
@ -63,6 +64,7 @@ class TransformChain(Chain):
"""
return self.output_variables
@override
def _call(
self,
inputs: dict[str, str],
@ -70,6 +72,7 @@ class TransformChain(Chain):
) -> dict[str, str]:
return self.transform_cb(inputs)
@override
async def _acall(
self,
inputs: dict[str, Any],

View File

@ -331,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
"""
return ["prediction", "reference"]
@override
def _call(
self,
inputs: dict[str, Any],
@ -355,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
score = self._compute_score(vectors)
return {"score": score}
@override
async def _acall(
self,
inputs: dict[str, Any],
@ -382,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
score = self._compute_score(vectors)
return {"score": score}
@override
def _evaluate_strings(
self,
*,
@ -416,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
)
return self._prepare_output(result)
@override
async def _aevaluate_strings(
self,
*,
@ -478,6 +482,7 @@ class PairwiseEmbeddingDistanceEvalChain(
"""Return the evaluation name."""
return f"pairwise_embedding_{self.distance_metric.value}_distance"
@override
def _call(
self,
inputs: dict[str, Any],
@ -505,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
score = self._compute_score(vectors)
return {"score": score}
@override
async def _acall(
self,
inputs: dict[str, Any],
@ -532,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
score = self._compute_score(vectors)
return {"score": score}
@override
def _evaluate_string_pairs(
self,
*,
@ -567,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
)
return self._prepare_output(result)
@override
async def _aevaluate_string_pairs(
self,
*,

View File

@ -1,6 +1,8 @@
import string
from typing import Any
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator
@ -78,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
"""
return "exact_match"
@override
def _evaluate_strings( # type: ignore[override]
self,
*,

View File

@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
""" # noqa: E501
def __init__(self, **kwargs: Any) -> None:
def __init__(self, **_: Any) -> None:
"""Initializes the JsonSchemaEvaluator.
Args:
kwargs: Additional keyword arguments.
Raises:
ImportError: If the jsonschema package is not installed.
"""

View File

@ -1,6 +1,8 @@
import re
from typing import Any
from typing_extensions import override
from langchain.evaluation.schema import StringEvaluator
@ -70,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
"""
return "regex_match"
@override
def _evaluate_strings( # type: ignore[override]
self,
*,

View File

@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""
return f"{self.distance.value}_distance"
@override
def _call(
self,
inputs: dict[str, Any],
@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
"""
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
@override
async def _acall(
self,
inputs: dict[str, Any],
@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"""
return f"pairwise_{self.distance.value}_distance"
@override
def _call(
self,
inputs: dict[str, Any],
@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
}
@override
async def _acall(
self,
inputs: dict[str, Any],
@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
}
@override
def _evaluate_string_pairs(
self,
*,
@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
)
return self._prepare_output(result)
@override
async def _aevaluate_string_pairs(
self,
*,

View File

@ -79,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}
@override
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return key-value pairs given the text input to the chain."""
buffer = await self.abuffer()
@ -133,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

View File

@ -2,6 +2,7 @@ from typing import Any, Union
from langchain_core._api import deprecated
from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory
@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

View File

@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin
from langchain_core.prompts import BasePromptTemplate
from langchain_core.utils import pre_init
from pydantic import BaseModel
from typing_extensions import override
from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
@ -133,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer."""
if self.return_messages:

View File

@ -3,6 +3,7 @@ from typing import Any, Union
from langchain_core._api import deprecated
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.utils import pre_init
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.summary import SummarizerMixin
@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer."""
buffer = self.chat_memory.messages
@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
)
return {self.memory_key: final_buffer}
@override
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Asynchronously return key-value pairs given the text input to the chain."""
buffer = await self.chat_memory.aget_messages()

View File

@ -3,6 +3,7 @@ from typing import Any
from langchain_core._api import deprecated
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, get_buffer_string
from typing_extensions import override
from langchain.memory.chat_memory import BaseChatMemory
@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
"""
return [self.memory_key]
@override
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Return history buffer."""
return {self.memory_key: self.buffer}

View File

@ -110,7 +110,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
llm: BaseLanguageModel,
prompt: Optional[PromptTemplate] = None,
get_input: Optional[Callable[[str, Document], str]] = None,
llm_chain_kwargs: Optional[dict] = None,
llm_chain_kwargs: Optional[dict] = None, # noqa: ARG003
) -> LLMChainExtractor:
"""Initialize from LLM."""
_prompt = prompt if prompt is not None else _get_default_chain_prompt()

View File

@ -9,6 +9,7 @@ from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.utils import get_from_dict_or_env
from pydantic import ConfigDict, model_validator
from typing_extensions import override
@deprecated(
@ -98,6 +99,7 @@ class CohereRerank(BaseDocumentCompressor):
for res in results
]
@override
def compress_documents(
self,
documents: Sequence[Document],

View File

@ -7,6 +7,7 @@ from typing import Optional
from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from pydantic import ConfigDict
from typing_extensions import override
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
@ -25,6 +26,7 @@ class CrossEncoderReranker(BaseDocumentCompressor):
extra="forbid",
)
@override
def compress_documents(
self,
documents: Sequence[Document],

View File

@ -6,6 +6,7 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.embeddings import Embeddings
from langchain_core.utils import pre_init
from pydantic import ConfigDict, Field
from typing_extensions import override
def _get_similarity_function() -> Callable:
@ -50,6 +51,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
raise ValueError(msg)
return values
@override
def compress_documents(
self,
documents: Sequence[Document],
@ -93,6 +95,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
stateful_documents[i].state["query_similarity_score"] = similarity[i]
return [stateful_documents[i] for i in included_idxs]
@override
async def acompress_documents(
self,
documents: Sequence[Document],

View File

@ -67,7 +67,7 @@ class MultiQueryRetriever(BaseRetriever):
retriever: BaseRetriever,
llm: BaseLanguageModel,
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
parser_key: Optional[str] = None,
parser_key: Optional[str] = None, # noqa: ARG003
include_original: bool = False, # noqa: FBT001,FBT002
) -> "MultiQueryRetriever":
"""Initialize from llm using default template.

View File

@ -10,6 +10,7 @@ from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore, ByteStore
from langchain_core.vectorstores import VectorStore
from pydantic import Field, model_validator
from typing_extensions import override
from langchain.storage._lc_store import create_kv_docstore
@ -54,6 +55,7 @@ class MultiVectorRetriever(BaseRetriever):
values["docstore"] = docstore
return values
@override
def _get_relevant_documents(
self,
query: str,
@ -91,6 +93,7 @@ class MultiVectorRetriever(BaseRetriever):
docs = self.docstore.mget(ids)
return [d for d in docs if d is not None]
@override
async def _aget_relevant_documents(
self,
query: str,

View File

@ -10,6 +10,7 @@ from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from pydantic import ConfigDict, Field
from typing_extensions import override
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
@ -128,6 +129,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
result.append(buffered_doc)
return result
@override
def _get_relevant_documents(
self,
query: str,
@ -142,6 +144,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
docs_and_scores.update(self.get_salient_docs(query))
return self._get_rescored_docs(docs_and_scores)
@override
async def _aget_relevant_documents(
self,
query: str,

View File

@ -19,7 +19,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
total: int,
ncols: int = 50,
end_with: str = "\n",
**kwargs: Any,
):
"""Initialize the progress bar.

View File

@ -355,6 +355,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
feedback.evaluator_info[RUN_KEY] = output[RUN_KEY]
return feedback
@override
def evaluate_run(
self,
run: Run,
@ -372,6 +373,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
# TODO: Add run ID once we can declare it via callbacks
)
@override
async def aevaluate_run(
self,
run: Run,

View File

@ -1,7 +1,7 @@
from typing import Any
def __getattr__(name: str = "") -> Any:
def __getattr__(_: str = "") -> Any:
msg = (
"This tool has been moved to langchain experiment. "
"This tool has access to a python REPL. "

View File

@ -145,6 +145,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
[tool.ruff.lint]
select = [
"A", # flake8-builtins
"ARG", # flake8-unused-arguments
"ASYNC", # flake8-async
"B", # flake8-bugbear
"C4", # flake8-comprehensions

View File

@ -3,6 +3,7 @@
import math
from langchain_core.embeddings import Embeddings
from typing_extensions import override
fake_texts = ["foo", "bar", "baz"]
@ -18,6 +19,7 @@ class FakeEmbeddings(Embeddings):
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
return self.embed_documents(texts)
@override
def embed_query(self, text: str) -> list[float]:
"""Return constant query embeddings.
Embeddings are identical to embed_documents(texts)[0].

View File

@ -25,6 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.utils import add
from langchain_core.tools import Tool, tool
from langchain_core.tracers import RunLog, RunLogPatch
from typing_extensions import override
from langchain.agents import (
AgentExecutor,
@ -48,6 +49,7 @@ class FakeListLLM(LLM):
responses: list[str]
i: int = -1
@override
def _call(
self,
prompt: str,
@ -462,7 +464,7 @@ async def test_runnable_agent() -> None:
],
)
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser."""
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
@ -569,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
],
)
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser."""
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@ -681,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
],
)
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
"""A parser."""
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
@ -1032,7 +1034,7 @@ async def test_openai_agent_tools_agent() -> None:
],
)
GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc]
GenericFakeChatModel.bind_tools = lambda self, _: self # type: ignore[assignment,misc]
model = GenericFakeChatModel(messages=infinite_cycle)
@tool

View File

@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables.utils import add
from langchain_core.tools import Tool
from typing_extensions import override
from langchain.agents import AgentExecutor, AgentType, initialize_agent
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@ -19,6 +20,7 @@ class FakeListLLM(LLM):
responses: list[str]
i: int = -1
@override
def _call(
self,
prompt: str,

View File

@ -364,7 +364,7 @@ def test_agent_iterator_failing_tool() -> None:
tools = [
Tool(
name="FailingTool",
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
func=lambda _: 1 / 0, # This tool will raise a ZeroDivisionError
description="A tool that fails",
),
]

View File

@ -8,7 +8,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
@tool
def my_tool(query: str) -> str:
def my_tool(query: str) -> str: # noqa: ARG001
"""A fake tool."""
return "fake tool"

View File

@ -141,11 +141,11 @@ def test_valid_action_and_answer_raises_exception() -> None:
def test_from_chains() -> None:
"""Test initializing from chains."""
chain_configs = [
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
Tool(name="foo", func=lambda _x: "foo", description="foobar1"),
Tool(name="bar", func=lambda _x: "bar", description="foobar2"),
]
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2"
expected_tools_prompt = "foo(_x) - foobar1\nbar(_x) - foobar2"
expected_tool_names = "foo, bar"
expected_template = "\n\n".join(
[

View File

@ -7,7 +7,7 @@ import pytest
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:
client = AsyncMock() if use_async else MagicMock()
mock_assistant = MagicMock()
mock_assistant.id = "abc123"

View File

@ -7,6 +7,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from typing_extensions import override
class BaseFakeCallbackHandler(BaseModel):
@ -135,6 +136,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
@override
def on_llm_start(
self,
*args: Any,
@ -142,6 +144,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_start_common()
@override
def on_llm_new_token(
self,
*args: Any,
@ -149,6 +152,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_new_token_common()
@override
def on_llm_end(
self,
*args: Any,
@ -156,6 +160,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_end_common()
@override
def on_llm_error(
self,
*args: Any,
@ -163,6 +168,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_llm_error_common()
@override
def on_retry(
self,
*args: Any,
@ -170,6 +176,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retry_common()
@override
def on_chain_start(
self,
*args: Any,
@ -177,6 +184,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_start_common()
@override
def on_chain_end(
self,
*args: Any,
@ -184,6 +192,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_end_common()
@override
def on_chain_error(
self,
*args: Any,
@ -191,6 +200,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_chain_error_common()
@override
def on_tool_start(
self,
*args: Any,
@ -198,6 +208,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_start_common()
@override
def on_tool_end(
self,
*args: Any,
@ -205,6 +216,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_end_common()
@override
def on_tool_error(
self,
*args: Any,
@ -212,6 +224,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_tool_error_common()
@override
def on_agent_action(
self,
*args: Any,
@ -219,6 +232,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_action_common()
@override
def on_agent_finish(
self,
*args: Any,
@ -226,6 +240,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_agent_finish_common()
@override
def on_text(
self,
*args: Any,
@ -233,6 +248,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_text_common()
@override
def on_retriever_start(
self,
*args: Any,
@ -240,6 +256,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_start_common()
@override
def on_retriever_end(
self,
*args: Any,
@ -247,6 +264,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
) -> Any:
self.on_retriever_end_common()
@override
def on_retriever_error(
self,
*args: Any,
@ -259,6 +277,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
@override
def on_chat_model_start(
self,
serialized: dict[str, Any],
@ -290,6 +309,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
@override
async def on_retry(
self,
*args: Any,
@ -297,6 +317,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> Any:
self.on_retry_common()
@override
async def on_llm_start(
self,
*args: Any,
@ -304,6 +325,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_start_common()
@override
async def on_llm_new_token(
self,
*args: Any,
@ -311,6 +333,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_new_token_common()
@override
async def on_llm_end(
self,
*args: Any,
@ -318,6 +341,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_end_common()
@override
async def on_llm_error(
self,
*args: Any,
@ -325,6 +349,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_llm_error_common()
@override
async def on_chain_start(
self,
*args: Any,
@ -332,6 +357,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_start_common()
@override
async def on_chain_end(
self,
*args: Any,
@ -339,6 +365,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_end_common()
@override
async def on_chain_error(
self,
*args: Any,
@ -346,6 +373,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_chain_error_common()
@override
async def on_tool_start(
self,
*args: Any,
@ -353,6 +381,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_start_common()
@override
async def on_tool_end(
self,
*args: Any,
@ -360,6 +389,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_end_common()
@override
async def on_tool_error(
self,
*args: Any,
@ -367,6 +397,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_tool_error_common()
@override
async def on_agent_action(
self,
*args: Any,
@ -374,6 +405,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_action_common()
@override
async def on_agent_finish(
self,
*args: Any,
@ -381,6 +413,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
) -> None:
self.on_agent_finish_common()
@override
async def on_text(
self,
*args: Any,

View File

@ -3,6 +3,7 @@ import re
from typing import Optional
from langchain_core.callbacks import CallbackManagerForChainRun
from typing_extensions import override
from langchain.callbacks import FileCallbackHandler
from langchain.chains.base import Chain
@ -25,6 +26,7 @@ class FakeChain(Chain):
"""Output key of bar."""
return self.the_output_keys
@override
def _call(
self,
inputs: dict[str, str],

View File

@ -2,6 +2,7 @@ from typing import Any, Optional
import pytest
from langchain_core.callbacks import CallbackManagerForChainRun
from typing_extensions import override
from langchain.callbacks import StdOutCallbackHandler
from langchain.chains.base import Chain
@ -24,6 +25,7 @@ class FakeChain(Chain):
"""Output key of bar."""
return self.the_output_keys
@override
def _call(
self,
inputs: dict[str, str],

View File

@ -7,6 +7,7 @@ import pytest
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.memory import BaseMemory
from langchain_core.tracers.context import collect_runs
from typing_extensions import override
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY
@ -21,6 +22,7 @@ class FakeMemory(BaseMemory):
"""Return baz variable."""
return ["baz"]
@override
def load_memory_variables(
self,
inputs: Optional[dict[str, Any]] = None,
@ -52,6 +54,7 @@ class FakeChain(Chain):
"""Output key of bar."""
return self.the_output_keys
@override
def _call(
self,
inputs: dict[str, str],

View File

@ -19,7 +19,7 @@ def _fake_docs_len_func(docs: list[Document]) -> int:
return len(_fake_combine_docs_func(docs))
def _fake_combine_docs_func(docs: list[Document], **kwargs: Any) -> str:
def _fake_combine_docs_func(docs: list[Document], **_: Any) -> str:
return "".join([d.page_content for d in docs])

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LLM
from langchain_core.memory import BaseMemory
from langchain_core.prompts.prompt import PromptTemplate
from typing_extensions import override
from langchain.chains.conversation.base import ConversationChain
from langchain.memory.buffer import ConversationBufferMemory
@ -26,6 +27,7 @@ class DummyLLM(LLM):
def _llm_type(self) -> str:
return "dummy"
@override
def _call(
self,
prompt: str,

View File

@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, LLMResult
from typing_extensions import override
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
from langchain.chains.hyde.prompts import PROMPT_MAP
@ -18,10 +19,12 @@ from langchain.chains.hyde.prompts import PROMPT_MAP
class FakeEmbeddings(Embeddings):
"""Fake embedding class for tests."""
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""Return random floats."""
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
@override
def embed_query(self, text: str) -> list[float]:
"""Return random floats."""
return list(np.random.uniform(0, 1, 10))
@ -32,6 +35,7 @@ class FakeLLM(BaseLLM):
n: int = 1
@override
def _generate(
self,
prompts: list[str],
@ -41,6 +45,7 @@ class FakeLLM(BaseLLM):
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@override
async def _agenerate(
self,
prompts: list[str],

View File

@ -8,6 +8,7 @@ from langchain_core.callbacks.manager import (
AsyncCallbackManagerForChainRun,
CallbackManagerForChainRun,
)
from typing_extensions import override
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
@ -32,6 +33,7 @@ class FakeChain(Chain):
"""Input keys this chain returns."""
return self.output_variables
@override
def _call(
self,
inputs: dict[str, str],
@ -43,6 +45,7 @@ class FakeChain(Chain):
outputs[var] = f"{' '.join(variables)}foo"
return outputs
@override
async def _acall(
self,
inputs: dict[str, str],

View File

@ -4,6 +4,7 @@ from collections.abc import Iterator
from langchain_core.document_loaders import BaseBlobParser, Blob
from langchain_core.documents import Document
from typing_extensions import override
def test_base_blob_parser() -> None:
@ -12,6 +13,7 @@ def test_base_blob_parser() -> None:
class MyParser(BaseBlobParser):
"""A simple parser that returns a single document."""
@override
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
"""Lazy parsing interface."""
yield Document(

View File

@ -7,12 +7,14 @@ import warnings
import pytest
from langchain_core.embeddings import Embeddings
from typing_extensions import override
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage.in_memory import InMemoryStore
class MockEmbeddings(Embeddings):
@override
def embed_documents(self, texts: list[str]) -> list[list[float]]:
# Simulate embedding documents
embeddings: list[list[float]] = []
@ -23,6 +25,7 @@ class MockEmbeddings(Embeddings):
embeddings.append([len(text), len(text) + 1])
return embeddings
@override
def embed_query(self, text: str) -> list[float]:
# Simulate embedding a query
return [5.0, 6.0]

View File

@ -9,6 +9,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.tools import tool
from pydantic import Field
from typing_extensions import override
from langchain.evaluation.agents.trajectory_eval_chain import (
TrajectoryEval,
@ -43,6 +44,7 @@ class _FakeTrajectoryChatModel(FakeChatModel):
sequential_responses: Optional[bool] = False
response_index: int = 0
@override
def _call(
self,
messages: list[BaseMessage],

View File

@ -13,6 +13,7 @@ from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.indexing.api import _abatch, _get_document_with_hash
from langchain_core.vectorstores import VST, VectorStore
from typing_extensions import override
from langchain.indexes import aindex, index
from langchain.indexes._sql_record_manager import SQLRecordManager
@ -45,18 +46,21 @@ class InMemoryVectorStore(VectorStore):
self.store: dict[str, Document] = {}
self.permit_upserts = permit_upserts
@override
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
@override
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
"""Delete the given documents from the store using their IDs."""
if ids:
for _id in ids:
self.store.pop(_id, None)
@override
def add_documents(
self,
documents: Sequence[Document],
@ -81,6 +85,7 @@ class InMemoryVectorStore(VectorStore):
return list(ids)
@override
async def aadd_documents(
self,
documents: Sequence[Document],

View File

@ -16,11 +16,13 @@ from langchain_core.messages import (
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor
from typing_extensions import override
class FakeChatModel(SimpleChatModel):
"""Fake Chat Model wrapper for testing purposes."""
@override
def _call(
self,
messages: list[BaseMessage],
@ -30,6 +32,7 @@ class FakeChatModel(SimpleChatModel):
) -> str:
return "fake response"
@override
async def _agenerate(
self,
messages: list[BaseMessage],
@ -74,6 +77,7 @@ class GenericFakeChatModel(BaseChatModel):
into message chunks.
"""
@override
def _generate(
self,
messages: list[BaseMessage],

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, cast
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from pydantic import model_validator
from typing_extensions import override
class FakeLLM(LLM):
@ -32,6 +33,7 @@ class FakeLLM(LLM):
"""Return type of llm."""
return "fake"
@override
def _call(
self,
prompt: str,

View File

@ -7,6 +7,7 @@ from uuid import UUID
from langchain_core.callbacks.base import AsyncCallbackHandler
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from typing_extensions import override
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
@ -166,6 +167,7 @@ async def test_callback_handlers() -> None:
# Required to implement since this is an abstract method
pass
@override
async def on_llm_new_token(
self,
token: str,

View File

@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from typing_extensions import override
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
@ -21,6 +22,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
parse_count: int = 0 # Number of times parse has been called
attemp_count_before_success: int # Number of times to fail before succeeding
@override
def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success:
@ -62,7 +64,7 @@ def test_output_fixing_parser_parse(
def test_output_fixing_parser_from_llm() -> None:
def fake_llm(prompt: str) -> AIMessage:
def fake_llm(_: str) -> AIMessage:
return AIMessage("2024-07-08T00:00:00.000000Z")
llm = RunnableLambda(fake_llm)

View File

@ -7,6 +7,7 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from typing_extensions import override
from langchain.output_parsers.boolean import BooleanOutputParser
from langchain.output_parsers.datetime import DatetimeOutputParser
@ -25,6 +26,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
attemp_count_before_success: int # Number of times to fail before succeeding
error_msg: str = "error"
@override
def parse(self, *args: Any, **kwargs: Any) -> str:
self.parse_count += 1
if self.parse_count <= self.attemp_count_before_success:

View File

@ -14,6 +14,7 @@ from langchain_core.structured_query import (
StructuredQuery,
Visitor,
)
from typing_extensions import override
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.retrievers import SelfQueryRetriever
@ -61,6 +62,7 @@ class FakeTranslator(Visitor):
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
@override
def similarity_search(
self,
query: str,

View File

@ -1,5 +1,8 @@
from typing import Any
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
class SequentialRetriever(BaseRetriever):
@ -8,17 +11,21 @@ class SequentialRetriever(BaseRetriever):
sequential_responses: list[list[Document]]
response_index: int = 0
def _get_relevant_documents( # type: ignore[override]
@override
def _get_relevant_documents(
self,
query: str,
**kwargs: Any,
) -> list[Document]:
if self.response_index >= len(self.sequential_responses):
return []
self.response_index += 1
return self.sequential_responses[self.response_index - 1]
async def _aget_relevant_documents( # type: ignore[override]
@override
async def _aget_relevant_documents(
self,
query: str,
**kwargs: Any,
) -> list[Document]:
return self._get_relevant_documents(query)

View File

@ -3,6 +3,7 @@ from typing import Optional
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from typing_extensions import override
from langchain.retrievers.ensemble import EnsembleRetriever
@ -10,6 +11,7 @@ from langchain.retrievers.ensemble import EnsembleRetriever
class MockRetriever(BaseRetriever):
docs: list[Document]
@override
def _get_relevant_documents(
self,
query: str,

View File

@ -1,6 +1,7 @@
from typing import Any, Callable
from langchain_core.documents import Document
from typing_extensions import override
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
from langchain.storage import InMemoryStore
@ -15,6 +16,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
def _select_relevance_score_fn(self) -> Callable[[float], float]:
return self._identity_fn
@override
def similarity_search(
self,
query: str,
@ -26,6 +28,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
return []
return [res]
@override
def similarity_search_with_score(
self,
query: str,

View File

@ -3,6 +3,7 @@ from typing import Any
from langchain_core.documents import Document
from langchain_text_splitters.character import CharacterTextSplitter
from typing_extensions import override
from langchain.retrievers import ParentDocumentRetriever
from langchain.storage import InMemoryStore
@ -10,6 +11,7 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
@override
def similarity_search(
self,
query: str,
@ -21,6 +23,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
return []
return [res]
@override
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
print(documents) # noqa: T201
return super().add_documents(

View File

@ -8,6 +8,7 @@ import pytest
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from typing_extensions import override
from langchain.retrievers.time_weighted_retriever import (
TimeWeightedVectorStoreRetriever,
@ -31,6 +32,7 @@ def _get_example_memories(k: int = 4) -> list[Document]:
class MockVectorStore(VectorStore):
"""Mock invalid vector store."""
@override
def add_texts(
self,
texts: Iterable[str],
@ -39,6 +41,7 @@ class MockVectorStore(VectorStore):
) -> list[str]:
return list(texts)
@override
def similarity_search(
self,
query: str,
@ -48,6 +51,7 @@ class MockVectorStore(VectorStore):
return []
@classmethod
@override
def from_texts(
cls: type["MockVectorStore"],
texts: list[str],
@ -57,6 +61,7 @@ class MockVectorStore(VectorStore):
) -> "MockVectorStore":
return cls()
@override
def _similarity_search_with_relevance_scores(
self,
query: str,

View File

@ -38,7 +38,7 @@ repo_dict = {
}
def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate:
def repo_lookup(owner_repo_commit: str, **_: Any) -> ChatPromptTemplate:
return repo_dict[owner_repo_commit]

View File

@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from pytest_mock import MockerFixture
from syrupy.assertion import SnapshotAssertion
from typing_extensions import override
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
@ -15,6 +16,7 @@ class FakeChatOpenAI(BaseChatModel):
def _llm_type(self) -> str:
return "fake-openai-chat-model"
@override
def _generate(
self,
messages: list[BaseMessage],

View File

@ -3,16 +3,14 @@
import uuid
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Any, Optional, Union
from typing import Any
from unittest import mock
import pytest
from freezegun import freeze_time
from langchain_core.language_models import BaseLanguageModel
from langsmith.client import Client
from langsmith.schemas import Dataset, Example
from langchain.chains.base import Chain
from langchain.chains.transform import TransformChain
from langchain.smith.evaluation.runner_utils import (
InputFormatError,
@ -243,7 +241,7 @@ def test_run_chat_model_all_formats(inputs: dict[str, Any]) -> None:
@freeze_time("2023-01-01")
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
async def test_arun_on_dataset() -> None:
dataset = Dataset(
id=uuid.uuid4(),
name="test",
@ -298,22 +296,20 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
),
]
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
def mock_read_dataset(*_: Any, **__: Any) -> Dataset:
return dataset
def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
def mock_list_examples(*_: Any, **__: Any) -> Iterator[Example]:
return iter(examples)
async def mock_arun_chain(
example: Example,
llm_or_chain: Union[BaseLanguageModel, Chain],
tags: Optional[list[str]] = None,
callbacks: Optional[Any] = None,
**kwargs: Any,
*_: Any,
**__: Any,
) -> dict[str, Any]:
return {"result": f"Result for example {example.id}"}
def mock_create_project(*args: Any, **kwargs: Any) -> Any:
def mock_create_project(*_: Any, **__: Any) -> Any:
proj = mock.MagicMock()
proj.id = "123"
return proj

View File

@ -8,13 +8,13 @@ from langchain.tools.render import (
@tool
def search(query: str) -> str:
def search(query: str) -> str: # noqa: ARG001
"""Lookup things online."""
return "foo"
@tool
def calculator(expression: str) -> str:
def calculator(expression: str) -> str: # noqa: ARG001
"""Do math."""
return "bar"