mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-07 12:06:43 +00:00
chore(langchain): add ruff rules ARG (#32110)
See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
a2ad5aca41
commit
efdfa00d10
@ -116,8 +116,8 @@ class BaseSingleActionAgent(BaseModel):
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
|
||||
**_: Any,
|
||||
) -> AgentFinish:
|
||||
"""Return response when agent has been stopped due to max iterations.
|
||||
|
||||
@ -125,7 +125,6 @@ class BaseSingleActionAgent(BaseModel):
|
||||
early_stopping_method: Method to use for early stopping.
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
AgentFinish: Agent finish object.
|
||||
@ -168,6 +167,7 @@ class BaseSingleActionAgent(BaseModel):
|
||||
"""Return Identifier of an agent type."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return dictionary representation of agent.
|
||||
|
||||
@ -289,8 +289,8 @@ class BaseMultiActionAgent(BaseModel):
|
||||
def return_stopped_response(
|
||||
self,
|
||||
early_stopping_method: str,
|
||||
intermediate_steps: list[tuple[AgentAction, str]],
|
||||
**kwargs: Any,
|
||||
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
|
||||
**_: Any,
|
||||
) -> AgentFinish:
|
||||
"""Return response when agent has been stopped due to max iterations.
|
||||
|
||||
@ -298,7 +298,6 @@ class BaseMultiActionAgent(BaseModel):
|
||||
early_stopping_method: Method to use for early stopping.
|
||||
intermediate_steps: Steps the LLM has taken to date,
|
||||
along with observations.
|
||||
**kwargs: User inputs.
|
||||
|
||||
Returns:
|
||||
AgentFinish: Agent finish object.
|
||||
@ -317,6 +316,7 @@ class BaseMultiActionAgent(BaseModel):
|
||||
"""Return Identifier of an agent type."""
|
||||
raise NotImplementedError
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().model_dump()
|
||||
@ -651,6 +651,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
||||
"""
|
||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
@ -735,6 +736,7 @@ class Agent(BaseSingleActionAgent):
|
||||
allowed_tools: Optional[list[str]] = None
|
||||
"""Allowed tools for the agent. If None, all tools are allowed."""
|
||||
|
||||
@override
|
||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||
"""Return dictionary representation of agent."""
|
||||
_dict = super().dict()
|
||||
@ -750,18 +752,6 @@ class Agent(BaseSingleActionAgent):
|
||||
"""Return values of the agent."""
|
||||
return ["output"]
|
||||
|
||||
def _fix_text(self, text: str) -> str:
|
||||
"""Fix the text.
|
||||
|
||||
Args:
|
||||
text: Text to fix.
|
||||
|
||||
Returns:
|
||||
str: Fixed text.
|
||||
"""
|
||||
msg = "fix_text not implemented for this agent."
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def _stop(self) -> list[str]:
|
||||
return [
|
||||
@ -1021,6 +1011,7 @@ class ExceptionTool(BaseTool):
|
||||
description: str = "Exception tool"
|
||||
"""Description of the tool."""
|
||||
|
||||
@override
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
@ -1028,6 +1019,7 @@ class ExceptionTool(BaseTool):
|
||||
) -> str:
|
||||
return query
|
||||
|
||||
@override
|
||||
async def _arun(
|
||||
self,
|
||||
query: str,
|
||||
@ -1188,6 +1180,7 @@ class AgentExecutor(Chain):
|
||||
return cast("RunnableAgentType", self.agent)
|
||||
return self.agent
|
||||
|
||||
@override
|
||||
def save(self, file_path: Union[Path, str]) -> None:
|
||||
"""Raise error - saving not supported for Agent Executors.
|
||||
|
||||
@ -1218,7 +1211,7 @@ class AgentExecutor(Chain):
|
||||
callbacks: Callbacks = None,
|
||||
*,
|
||||
include_run_info: bool = False,
|
||||
async_: bool = False, # arg kept for backwards compat, but ignored
|
||||
async_: bool = False, # noqa: ARG002 arg kept for backwards compat, but ignored
|
||||
) -> AgentExecutorIterator:
|
||||
"""Enables iteration over steps taken to reach final output.
|
||||
|
||||
|
@ -13,6 +13,7 @@ from langchain_core.prompts.chat import (
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
@ -65,6 +66,7 @@ class ChatAgent(Agent):
|
||||
return agent_scratchpad
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return ChatOutputParser()
|
||||
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
@ -35,6 +36,7 @@ class ConversationalAgent(Agent):
|
||||
"""Output parser for the agent."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(
|
||||
cls,
|
||||
ai_prefix: str = "AI",
|
||||
|
@ -20,6 +20,7 @@ from langchain_core.prompts.chat import (
|
||||
)
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.agent import Agent, AgentOutputParser
|
||||
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
|
||||
@ -42,6 +43,7 @@ class ConversationalChatAgent(Agent):
|
||||
"""Template for the tool response."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return ConvoOutputParser()
|
||||
|
||||
|
@ -12,6 +12,7 @@ from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from langchain_core.tools.render import render_text_description
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
@ -51,6 +52,7 @@ class ZeroShotAgent(Agent):
|
||||
output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return MRKLOutputParser()
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import Any
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.format_scratchpad import (
|
||||
format_to_openai_function_messages,
|
||||
@ -55,6 +56,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer.
|
||||
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
|
||||
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return ReActOutputParser()
|
||||
|
||||
@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
|
||||
return AgentType.REACT_DOCSTORE
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
"""Return default prompt."""
|
||||
return WIKI_PROMPT
|
||||
@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
||||
"""Agent for the ReAct TextWorld chain."""
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
"""Return default prompt."""
|
||||
return TEXTWORLD_PROMPT
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||
from langchain_core.tools import BaseTool, Tool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||
from langchain.agents.agent_types import AgentType
|
||||
@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
|
||||
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||
return SelfAskOutputParser()
|
||||
|
||||
@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
|
||||
return AgentType.SELF_ASK_WITH_SEARCH
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||
"""Prompt does not depend on tools."""
|
||||
return PROMPT
|
||||
|
@ -71,6 +71,7 @@ class StructuredChatAgent(Agent):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def _get_default_output_parser(
|
||||
cls,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain_core.tools import BaseTool, tool
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class InvalidTool(BaseTool):
|
||||
@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
|
||||
description: str = "Called when tool name is invalid. Suggests valid tool names."
|
||||
"""Description of the tool."""
|
||||
|
||||
@override
|
||||
def _run(
|
||||
self,
|
||||
requested_tool_name: str,
|
||||
@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
|
||||
f"try one of [{available_tool_names_str}]."
|
||||
)
|
||||
|
||||
@override
|
||||
async def _arun(
|
||||
self,
|
||||
requested_tool_name: str,
|
||||
|
@ -4,6 +4,7 @@ import sys
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain_core.callbacks import StreamingStdOutCallbackHandler
|
||||
from typing_extensions import override
|
||||
|
||||
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
||||
|
||||
@ -63,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
self.stream_prefix = stream_prefix
|
||||
self.answer_reached = False
|
||||
|
||||
@override
|
||||
def on_llm_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@ -72,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
||||
"""Run when LLM starts running."""
|
||||
self.answer_reached = False
|
||||
|
||||
@override
|
||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||
|
||||
|
@ -388,7 +388,7 @@ except ImportError:
|
||||
class APIChain: # type: ignore[no-redef]
|
||||
"""Raise an ImportError if APIChain is used without langchain_community."""
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
def __init__(self, *_: Any, **__: Any) -> None:
|
||||
"""Raise an ImportError if APIChain is used without langchain_community."""
|
||||
msg = (
|
||||
"To use the APIChain, you must install the langchain_community package."
|
||||
|
@ -83,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
|
||||
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
|
||||
"""Return the prompt length given the documents passed in.
|
||||
|
||||
This can be used by a caller to determine whether passing in a list
|
||||
|
@ -402,6 +402,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
@override
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
@ -416,6 +417,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
||||
)
|
||||
return self._reduce_tokens_below_limit(docs)
|
||||
|
||||
@override
|
||||
async def _aget_docs(
|
||||
self,
|
||||
question: str,
|
||||
@ -512,6 +514,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
||||
)
|
||||
return values
|
||||
|
||||
@override
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
|
@ -1,4 +1,4 @@
|
||||
def __getattr__(name: str = "") -> None:
|
||||
def __getattr__(_: str = "") -> None:
|
||||
"""Raise an error on import since is deprecated."""
|
||||
msg = (
|
||||
"This module has been moved to langchain-experimental. "
|
||||
|
@ -1,4 +1,4 @@
|
||||
def __getattr__(name: str = "") -> None:
|
||||
def __getattr__(_: str = "") -> None:
|
||||
"""Raise an error on import since is deprecated."""
|
||||
msg = (
|
||||
"This module has been moved to langchain-experimental. "
|
||||
|
@ -39,7 +39,7 @@ try:
|
||||
from langchain_community.llms.loading import load_llm, load_llm_from_config
|
||||
except ImportError:
|
||||
|
||||
def load_llm(*args: Any, **kwargs: Any) -> None:
|
||||
def load_llm(*_: Any, **__: Any) -> None:
|
||||
"""Import error for load_llm."""
|
||||
msg = (
|
||||
"To use this load_llm functionality you must install the "
|
||||
@ -48,7 +48,7 @@ except ImportError:
|
||||
)
|
||||
raise ImportError(msg)
|
||||
|
||||
def load_llm_from_config(*args: Any, **kwargs: Any) -> None:
|
||||
def load_llm_from_config(*_: Any, **__: Any) -> None:
|
||||
"""Import error for load_llm_from_config."""
|
||||
msg = (
|
||||
"To use this load_llm_from_config functionality you must install the "
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks import (
|
||||
)
|
||||
from langchain_core.utils import check_package_version, get_from_dict_or_env
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
@ -105,6 +106,7 @@ class OpenAIModerationChain(Chain):
|
||||
return error_str
|
||||
return text
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
|
@ -16,6 +16,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains import ReduceDocumentsChain
|
||||
from langchain.chains.base import Chain
|
||||
@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
"""
|
||||
return [self.input_docs_key, self.question_key]
|
||||
|
||||
@override
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
"""Get docs to run questioning over."""
|
||||
return inputs.pop(self.input_docs_key)
|
||||
|
||||
@override
|
||||
async def _aget_docs(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||
@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
||||
|
||||
return docs[:num_docs]
|
||||
|
||||
@override
|
||||
def _get_docs(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
|
@ -11,7 +11,7 @@ try:
|
||||
from lark import Lark, Transformer, v_args
|
||||
except ImportError:
|
||||
|
||||
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
|
||||
def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc]
|
||||
"""Dummy decorator for when lark is not installed."""
|
||||
return lambda _: None
|
||||
|
||||
|
@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||
@ -330,6 +331,7 @@ class VectorDBQA(BaseRetrievalQA):
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@override
|
||||
def _get_docs(
|
||||
self,
|
||||
question: str,
|
||||
|
@ -11,6 +11,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.router.base import RouterChain
|
||||
|
||||
@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
|
||||
"""
|
||||
return self.routing_keys
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
|
||||
results = self.vectorstore.similarity_search(_input, k=1)
|
||||
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
@ -63,6 +64,7 @@ class TransformChain(Chain):
|
||||
"""
|
||||
return self.output_variables
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
@ -70,6 +72,7 @@ class TransformChain(Chain):
|
||||
) -> dict[str, str]:
|
||||
return self.transform_cb(inputs)
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
|
@ -331,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
"""
|
||||
return ["prediction", "reference"]
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -355,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -382,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
|
||||
@override
|
||||
def _evaluate_strings(
|
||||
self,
|
||||
*,
|
||||
@ -416,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
||||
)
|
||||
return self._prepare_output(result)
|
||||
|
||||
@override
|
||||
async def _aevaluate_strings(
|
||||
self,
|
||||
*,
|
||||
@ -478,6 +482,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
"""Return the evaluation name."""
|
||||
return f"pairwise_embedding_{self.distance_metric.value}_distance"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -505,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -532,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
score = self._compute_score(vectors)
|
||||
return {"score": score}
|
||||
|
||||
@override
|
||||
def _evaluate_string_pairs(
|
||||
self,
|
||||
*,
|
||||
@ -567,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
||||
)
|
||||
return self._prepare_output(result)
|
||||
|
||||
@override
|
||||
async def _aevaluate_string_pairs(
|
||||
self,
|
||||
*,
|
||||
|
@ -1,6 +1,8 @@
|
||||
import string
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.evaluation.schema import StringEvaluator
|
||||
|
||||
|
||||
@ -78,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
||||
"""
|
||||
return "exact_match"
|
||||
|
||||
@override
|
||||
def _evaluate_strings( # type: ignore[override]
|
||||
self,
|
||||
*,
|
||||
|
@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
def __init__(self, **_: Any) -> None:
|
||||
"""Initializes the JsonSchemaEvaluator.
|
||||
|
||||
Args:
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ImportError: If the jsonschema package is not installed.
|
||||
"""
|
||||
|
@ -1,6 +1,8 @@
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.evaluation.schema import StringEvaluator
|
||||
|
||||
|
||||
@ -70,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
||||
"""
|
||||
return "regex_match"
|
||||
|
||||
@override
|
||||
def _evaluate_strings( # type: ignore[override]
|
||||
self,
|
||||
*,
|
||||
|
@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
||||
"""
|
||||
return f"{self.distance.value}_distance"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
||||
"""
|
||||
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
||||
"""
|
||||
return f"pairwise_{self.distance.value}_distance"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
||||
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
||||
}
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, Any],
|
||||
@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
||||
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
||||
}
|
||||
|
||||
@override
|
||||
def _evaluate_string_pairs(
|
||||
self,
|
||||
*,
|
||||
@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
||||
)
|
||||
return self._prepare_output(result)
|
||||
|
||||
@override
|
||||
async def _aevaluate_string_pairs(
|
||||
self,
|
||||
*,
|
||||
|
@ -79,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
||||
@override
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.abuffer()
|
||||
@ -133,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
@ -2,6 +2,7 @@ from typing import Any, Union
|
||||
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
@ -133,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
if self.return_messages:
|
||||
|
@ -3,6 +3,7 @@ from typing import Any, Union
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from langchain_core.utils import pre_init
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.memory.summary import SummarizerMixin
|
||||
@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
buffer = self.chat_memory.messages
|
||||
@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
||||
)
|
||||
return {self.memory_key: final_buffer}
|
||||
|
||||
@override
|
||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Asynchronously return key-value pairs given the text input to the chain."""
|
||||
buffer = await self.chat_memory.aget_messages()
|
||||
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
from langchain_core._api import deprecated
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
|
||||
@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
|
||||
"""
|
||||
return [self.memory_key]
|
||||
|
||||
@override
|
||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Return history buffer."""
|
||||
return {self.memory_key: self.buffer}
|
||||
|
@ -110,7 +110,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
||||
llm: BaseLanguageModel,
|
||||
prompt: Optional[PromptTemplate] = None,
|
||||
get_input: Optional[Callable[[str, Document], str]] = None,
|
||||
llm_chain_kwargs: Optional[dict] = None,
|
||||
llm_chain_kwargs: Optional[dict] = None, # noqa: ARG003
|
||||
) -> LLMChainExtractor:
|
||||
"""Initialize from LLM."""
|
||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||
|
@ -9,6 +9,7 @@ from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.utils import get_from_dict_or_env
|
||||
from pydantic import ConfigDict, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@deprecated(
|
||||
@ -98,6 +99,7 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
for res in results
|
||||
]
|
||||
|
||||
@override
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
|
||||
|
||||
@ -25,6 +26,7 @@ class CrossEncoderReranker(BaseDocumentCompressor):
|
||||
extra="forbid",
|
||||
)
|
||||
|
||||
@override
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
|
@ -6,6 +6,7 @@ from langchain_core.documents import BaseDocumentCompressor, Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.utils import pre_init
|
||||
from pydantic import ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def _get_similarity_function() -> Callable:
|
||||
@ -50,6 +51,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
raise ValueError(msg)
|
||||
return values
|
||||
|
||||
@override
|
||||
def compress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -93,6 +95,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
stateful_documents[i].state["query_similarity_score"] = similarity[i]
|
||||
return [stateful_documents[i] for i in included_idxs]
|
||||
|
||||
@override
|
||||
async def acompress_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
|
@ -67,7 +67,7 @@ class MultiQueryRetriever(BaseRetriever):
|
||||
retriever: BaseRetriever,
|
||||
llm: BaseLanguageModel,
|
||||
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
|
||||
parser_key: Optional[str] = None,
|
||||
parser_key: Optional[str] = None, # noqa: ARG003
|
||||
include_original: bool = False, # noqa: FBT001,FBT002
|
||||
) -> "MultiQueryRetriever":
|
||||
"""Initialize from llm using default template.
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.stores import BaseStore, ByteStore
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.storage._lc_store import create_kv_docstore
|
||||
|
||||
@ -54,6 +55,7 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
values["docstore"] = docstore
|
||||
return values
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@ -91,6 +93,7 @@ class MultiVectorRetriever(BaseRetriever):
|
||||
docs = self.docstore.mget(ids)
|
||||
return [d for d in docs if d is not None]
|
||||
|
||||
@override
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from pydantic import ConfigDict, Field
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
|
||||
@ -128,6 +129,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
||||
result.append(buffered_doc)
|
||||
return result
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
@ -142,6 +144,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
||||
docs_and_scores.update(self.get_salient_docs(query))
|
||||
return self._get_rescored_docs(docs_and_scores)
|
||||
|
||||
@override
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -19,7 +19,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
|
||||
total: int,
|
||||
ncols: int = 50,
|
||||
end_with: str = "\n",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize the progress bar.
|
||||
|
||||
|
@ -355,6 +355,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
feedback.evaluator_info[RUN_KEY] = output[RUN_KEY]
|
||||
return feedback
|
||||
|
||||
@override
|
||||
def evaluate_run(
|
||||
self,
|
||||
run: Run,
|
||||
@ -372,6 +373,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
# TODO: Add run ID once we can declare it via callbacks
|
||||
)
|
||||
|
||||
@override
|
||||
async def aevaluate_run(
|
||||
self,
|
||||
run: Run,
|
||||
|
@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
|
||||
def __getattr__(name: str = "") -> Any:
|
||||
def __getattr__(_: str = "") -> Any:
|
||||
msg = (
|
||||
"This tool has been moved to langchain experiment. "
|
||||
"This tool has access to a python REPL. "
|
||||
|
@ -145,6 +145,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"A", # flake8-builtins
|
||||
"ARG", # flake8-unused-arguments
|
||||
"ASYNC", # flake8-async
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
|
@ -3,6 +3,7 @@
|
||||
import math
|
||||
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from typing_extensions import override
|
||||
|
||||
fake_texts = ["foo", "bar", "baz"]
|
||||
|
||||
@ -18,6 +19,7 @@ class FakeEmbeddings(Embeddings):
|
||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
return self.embed_documents(texts)
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return constant query embeddings.
|
||||
Embeddings are identical to embed_documents(texts)[0].
|
||||
|
@ -25,6 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables.utils import add
|
||||
from langchain_core.tools import Tool, tool
|
||||
from langchain_core.tracers import RunLog, RunLogPatch
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents import (
|
||||
AgentExecutor,
|
||||
@ -48,6 +49,7 @@ class FakeListLLM(LLM):
|
||||
responses: list[str]
|
||||
i: int = -1
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -462,7 +464,7 @@ async def test_runnable_agent() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
|
||||
|
||||
@ -569,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||
|
||||
@ -681,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
||||
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||
"""A parser."""
|
||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||
|
||||
@ -1032,7 +1034,7 @@ async def test_openai_agent_tools_agent() -> None:
|
||||
],
|
||||
)
|
||||
|
||||
GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc]
|
||||
GenericFakeChatModel.bind_tools = lambda self, _: self # type: ignore[assignment,misc]
|
||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||
|
||||
@tool
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langchain_core.runnables.utils import add
|
||||
from langchain_core.tools import Tool
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
@ -19,6 +20,7 @@ class FakeListLLM(LLM):
|
||||
responses: list[str]
|
||||
i: int = -1
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -364,7 +364,7 @@ def test_agent_iterator_failing_tool() -> None:
|
||||
tools = [
|
||||
Tool(
|
||||
name="FailingTool",
|
||||
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
|
||||
func=lambda _: 1 / 0, # This tool will raise a ZeroDivisionError
|
||||
description="A tool that fails",
|
||||
),
|
||||
]
|
||||
|
@ -8,7 +8,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
||||
|
||||
|
||||
@tool
|
||||
def my_tool(query: str) -> str:
|
||||
def my_tool(query: str) -> str: # noqa: ARG001
|
||||
"""A fake tool."""
|
||||
return "fake tool"
|
||||
|
||||
|
@ -141,11 +141,11 @@ def test_valid_action_and_answer_raises_exception() -> None:
|
||||
def test_from_chains() -> None:
|
||||
"""Test initializing from chains."""
|
||||
chain_configs = [
|
||||
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
|
||||
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
|
||||
Tool(name="foo", func=lambda _x: "foo", description="foobar1"),
|
||||
Tool(name="bar", func=lambda _x: "bar", description="foobar2"),
|
||||
]
|
||||
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
|
||||
expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2"
|
||||
expected_tools_prompt = "foo(_x) - foobar1\nbar(_x) - foobar2"
|
||||
expected_tool_names = "foo, bar"
|
||||
expected_template = "\n\n".join(
|
||||
[
|
||||
|
@ -7,7 +7,7 @@ import pytest
|
||||
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
|
||||
|
||||
|
||||
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
|
||||
def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:
|
||||
client = AsyncMock() if use_async else MagicMock()
|
||||
mock_assistant = MagicMock()
|
||||
mock_assistant.id = "abc123"
|
||||
|
@ -7,6 +7,7 @@ from uuid import UUID
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class BaseFakeCallbackHandler(BaseModel):
|
||||
@ -135,6 +136,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
"""Whether to ignore retriever callbacks."""
|
||||
return self.ignore_retriever_
|
||||
|
||||
@override
|
||||
def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -142,6 +144,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_start_common()
|
||||
|
||||
@override
|
||||
def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -149,6 +152,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
@override
|
||||
def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -156,6 +160,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_end_common()
|
||||
|
||||
@override
|
||||
def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -163,6 +168,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_llm_error_common()
|
||||
|
||||
@override
|
||||
def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -170,6 +176,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
@override
|
||||
def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -177,6 +184,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_start_common()
|
||||
|
||||
@override
|
||||
def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -184,6 +192,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_end_common()
|
||||
|
||||
@override
|
||||
def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -191,6 +200,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_chain_error_common()
|
||||
|
||||
@override
|
||||
def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -198,6 +208,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_start_common()
|
||||
|
||||
@override
|
||||
def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -205,6 +216,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_end_common()
|
||||
|
||||
@override
|
||||
def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -212,6 +224,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_tool_error_common()
|
||||
|
||||
@override
|
||||
def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -219,6 +232,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_agent_action_common()
|
||||
|
||||
@override
|
||||
def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -226,6 +240,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
@override
|
||||
def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -233,6 +248,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_text_common()
|
||||
|
||||
@override
|
||||
def on_retriever_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -240,6 +256,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_start_common()
|
||||
|
||||
@override
|
||||
def on_retriever_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -247,6 +264,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
) -> Any:
|
||||
self.on_retriever_end_common()
|
||||
|
||||
@override
|
||||
def on_retriever_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -259,6 +277,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
||||
|
||||
|
||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||
@override
|
||||
def on_chat_model_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@ -290,6 +309,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
"""Whether to ignore agent callbacks."""
|
||||
return self.ignore_agent_
|
||||
|
||||
@override
|
||||
async def on_retry(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -297,6 +317,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> Any:
|
||||
self.on_retry_common()
|
||||
|
||||
@override
|
||||
async def on_llm_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -304,6 +325,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_start_common()
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -311,6 +333,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_new_token_common()
|
||||
|
||||
@override
|
||||
async def on_llm_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -318,6 +341,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_end_common()
|
||||
|
||||
@override
|
||||
async def on_llm_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -325,6 +349,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_llm_error_common()
|
||||
|
||||
@override
|
||||
async def on_chain_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -332,6 +357,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_start_common()
|
||||
|
||||
@override
|
||||
async def on_chain_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -339,6 +365,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_end_common()
|
||||
|
||||
@override
|
||||
async def on_chain_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -346,6 +373,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_chain_error_common()
|
||||
|
||||
@override
|
||||
async def on_tool_start(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -353,6 +381,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_start_common()
|
||||
|
||||
@override
|
||||
async def on_tool_end(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -360,6 +389,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_end_common()
|
||||
|
||||
@override
|
||||
async def on_tool_error(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -367,6 +397,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_tool_error_common()
|
||||
|
||||
@override
|
||||
async def on_agent_action(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -374,6 +405,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_agent_action_common()
|
||||
|
||||
@override
|
||||
async def on_agent_finish(
|
||||
self,
|
||||
*args: Any,
|
||||
@ -381,6 +413,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
||||
) -> None:
|
||||
self.on_agent_finish_common()
|
||||
|
||||
@override
|
||||
async def on_text(
|
||||
self,
|
||||
*args: Any,
|
||||
|
@ -3,6 +3,7 @@ import re
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.chains.base import Chain
|
||||
@ -25,6 +26,7 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
|
@ -2,6 +2,7 @@ from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.callbacks import StdOutCallbackHandler
|
||||
from langchain.chains.base import Chain
|
||||
@ -24,6 +25,7 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
|
@ -7,6 +7,7 @@ import pytest
|
||||
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.tracers.context import collect_runs
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import RUN_KEY
|
||||
@ -21,6 +22,7 @@ class FakeMemory(BaseMemory):
|
||||
"""Return baz variable."""
|
||||
return ["baz"]
|
||||
|
||||
@override
|
||||
def load_memory_variables(
|
||||
self,
|
||||
inputs: Optional[dict[str, Any]] = None,
|
||||
@ -52,6 +54,7 @@ class FakeChain(Chain):
|
||||
"""Output key of bar."""
|
||||
return self.the_output_keys
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
|
@ -19,7 +19,7 @@ def _fake_docs_len_func(docs: list[Document]) -> int:
|
||||
return len(_fake_combine_docs_func(docs))
|
||||
|
||||
|
||||
def _fake_combine_docs_func(docs: list[Document], **kwargs: Any) -> str:
|
||||
def _fake_combine_docs_func(docs: list[Document], **_: Any) -> str:
|
||||
return "".join([d.page_content for d in docs])
|
||||
|
||||
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import LLM
|
||||
from langchain_core.memory import BaseMemory
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.memory.buffer import ConversationBufferMemory
|
||||
@ -26,6 +27,7 @@ class DummyLLM(LLM):
|
||||
def _llm_type(self) -> str:
|
||||
return "dummy"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models.llms import BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
@ -18,10 +19,12 @@ from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||
class FakeEmbeddings(Embeddings):
|
||||
"""Fake embedding class for tests."""
|
||||
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Return random floats."""
|
||||
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
"""Return random floats."""
|
||||
return list(np.random.uniform(0, 1, 10))
|
||||
@ -32,6 +35,7 @@ class FakeLLM(BaseLLM):
|
||||
|
||||
n: int = 1
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
@ -41,6 +45,7 @@ class FakeLLM(BaseLLM):
|
||||
) -> LLMResult:
|
||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
prompts: list[str],
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
CallbackManagerForChainRun,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
@ -32,6 +33,7 @@ class FakeChain(Chain):
|
||||
"""Input keys this chain returns."""
|
||||
return self.output_variables
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
@ -43,6 +45,7 @@ class FakeChain(Chain):
|
||||
outputs[var] = f"{' '.join(variables)}foo"
|
||||
return outputs
|
||||
|
||||
@override
|
||||
async def _acall(
|
||||
self,
|
||||
inputs: dict[str, str],
|
||||
|
@ -4,6 +4,7 @@ from collections.abc import Iterator
|
||||
|
||||
from langchain_core.document_loaders import BaseBlobParser, Blob
|
||||
from langchain_core.documents import Document
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def test_base_blob_parser() -> None:
|
||||
@ -12,6 +13,7 @@ def test_base_blob_parser() -> None:
|
||||
class MyParser(BaseBlobParser):
|
||||
"""A simple parser that returns a single document."""
|
||||
|
||||
@override
|
||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||
"""Lazy parsing interface."""
|
||||
yield Document(
|
||||
|
@ -7,12 +7,14 @@ import warnings
|
||||
|
||||
import pytest
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.embeddings import CacheBackedEmbeddings
|
||||
from langchain.storage.in_memory import InMemoryStore
|
||||
|
||||
|
||||
class MockEmbeddings(Embeddings):
|
||||
@override
|
||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||
# Simulate embedding documents
|
||||
embeddings: list[list[float]] = []
|
||||
@ -23,6 +25,7 @@ class MockEmbeddings(Embeddings):
|
||||
embeddings.append([len(text), len(text) + 1])
|
||||
return embeddings
|
||||
|
||||
@override
|
||||
def embed_query(self, text: str) -> list[float]:
|
||||
# Simulate embedding a query
|
||||
return [5.0, 6.0]
|
||||
|
@ -9,6 +9,7 @@ from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.tools import tool
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.evaluation.agents.trajectory_eval_chain import (
|
||||
TrajectoryEval,
|
||||
@ -43,6 +44,7 @@ class _FakeTrajectoryChatModel(FakeChatModel):
|
||||
sequential_responses: Optional[bool] = False
|
||||
response_index: int = 0
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
|
@ -13,6 +13,7 @@ from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.indexing.api import _abatch, _get_document_with_hash
|
||||
from langchain_core.vectorstores import VST, VectorStore
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.indexes import aindex, index
|
||||
from langchain.indexes._sql_record_manager import SQLRecordManager
|
||||
@ -45,18 +46,21 @@ class InMemoryVectorStore(VectorStore):
|
||||
self.store: dict[str, Document] = {}
|
||||
self.permit_upserts = permit_upserts
|
||||
|
||||
@override
|
||||
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete the given documents from the store using their IDs."""
|
||||
if ids:
|
||||
for _id in ids:
|
||||
self.store.pop(_id, None)
|
||||
|
||||
@override
|
||||
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||
"""Delete the given documents from the store using their IDs."""
|
||||
if ids:
|
||||
for _id in ids:
|
||||
self.store.pop(_id, None)
|
||||
|
||||
@override
|
||||
def add_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
@ -81,6 +85,7 @@ class InMemoryVectorStore(VectorStore):
|
||||
|
||||
return list(ids)
|
||||
|
||||
@override
|
||||
async def aadd_documents(
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
|
@ -16,11 +16,13 @@ from langchain_core.messages import (
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class FakeChatModel(SimpleChatModel):
|
||||
"""Fake Chat Model wrapper for testing purposes."""
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -30,6 +32,7 @@ class FakeChatModel(SimpleChatModel):
|
||||
) -> str:
|
||||
return "fake response"
|
||||
|
||||
@override
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
@ -74,6 +77,7 @@ class GenericFakeChatModel(BaseChatModel):
|
||||
into message chunks.
|
||||
"""
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
|
@ -6,6 +6,7 @@ from typing import Any, Optional, cast
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from pydantic import model_validator
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class FakeLLM(LLM):
|
||||
@ -32,6 +33,7 @@ class FakeLLM(LLM):
|
||||
"""Return type of llm."""
|
||||
return "fake"
|
||||
|
||||
@override
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
|
@ -7,6 +7,7 @@ from uuid import UUID
|
||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||
from typing_extensions import override
|
||||
|
||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||
@ -166,6 +167,7 @@ async def test_callback_handlers() -> None:
|
||||
# Required to implement since this is an abstract method
|
||||
pass
|
||||
|
||||
@override
|
||||
async def on_llm_new_token(
|
||||
self,
|
||||
token: str,
|
||||
|
@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||
@ -21,6 +22,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
parse_count: int = 0 # Number of times parse has been called
|
||||
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||
|
||||
@override
|
||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||
self.parse_count += 1
|
||||
if self.parse_count <= self.attemp_count_before_success:
|
||||
@ -62,7 +64,7 @@ def test_output_fixing_parser_parse(
|
||||
|
||||
|
||||
def test_output_fixing_parser_from_llm() -> None:
|
||||
def fake_llm(prompt: str) -> AIMessage:
|
||||
def fake_llm(_: str) -> AIMessage:
|
||||
return AIMessage("2024-07-08T00:00:00.000000Z")
|
||||
|
||||
llm = RunnableLambda(fake_llm)
|
||||
|
@ -7,6 +7,7 @@ from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||
@ -25,6 +26,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
||||
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||
error_msg: str = "error"
|
||||
|
||||
@override
|
||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||
self.parse_count += 1
|
||||
if self.parse_count <= self.attemp_count_before_success:
|
||||
|
@ -14,6 +14,7 @@ from langchain_core.structured_query import (
|
||||
StructuredQuery,
|
||||
Visitor,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||
from langchain.retrievers import SelfQueryRetriever
|
||||
@ -61,6 +62,7 @@ class FakeTranslator(Visitor):
|
||||
|
||||
|
||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
@override
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -1,5 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
class SequentialRetriever(BaseRetriever):
|
||||
@ -8,17 +11,21 @@ class SequentialRetriever(BaseRetriever):
|
||||
sequential_responses: list[list[Document]]
|
||||
response_index: int = 0
|
||||
|
||||
def _get_relevant_documents( # type: ignore[override]
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
if self.response_index >= len(self.sequential_responses):
|
||||
return []
|
||||
self.response_index += 1
|
||||
return self.sequential_responses[self.response_index - 1]
|
||||
|
||||
async def _aget_relevant_documents( # type: ignore[override]
|
||||
@override
|
||||
async def _aget_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
**kwargs: Any,
|
||||
) -> list[Document]:
|
||||
return self._get_relevant_documents(query)
|
||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
|
||||
@ -10,6 +11,7 @@ from langchain.retrievers.ensemble import EnsembleRetriever
|
||||
class MockRetriever(BaseRetriever):
|
||||
docs: list[Document]
|
||||
|
||||
@override
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -1,6 +1,7 @@
|
||||
from typing import Any, Callable
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
|
||||
from langchain.storage import InMemoryStore
|
||||
@ -15,6 +16,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._identity_fn
|
||||
|
||||
@override
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
@ -26,6 +28,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
return []
|
||||
return [res]
|
||||
|
||||
@override
|
||||
def similarity_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -3,6 +3,7 @@ from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_text_splitters.character import CharacterTextSplitter
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.retrievers import ParentDocumentRetriever
|
||||
from langchain.storage import InMemoryStore
|
||||
@ -10,6 +11,7 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
||||
|
||||
|
||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
@override
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
@ -21,6 +23,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||
return []
|
||||
return [res]
|
||||
|
||||
@override
|
||||
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
|
||||
print(documents) # noqa: T201
|
||||
return super().add_documents(
|
||||
|
@ -8,6 +8,7 @@ import pytest
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.retrievers.time_weighted_retriever import (
|
||||
TimeWeightedVectorStoreRetriever,
|
||||
@ -31,6 +32,7 @@ def _get_example_memories(k: int = 4) -> list[Document]:
|
||||
class MockVectorStore(VectorStore):
|
||||
"""Mock invalid vector store."""
|
||||
|
||||
@override
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -39,6 +41,7 @@ class MockVectorStore(VectorStore):
|
||||
) -> list[str]:
|
||||
return list(texts)
|
||||
|
||||
@override
|
||||
def similarity_search(
|
||||
self,
|
||||
query: str,
|
||||
@ -48,6 +51,7 @@ class MockVectorStore(VectorStore):
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def from_texts(
|
||||
cls: type["MockVectorStore"],
|
||||
texts: list[str],
|
||||
@ -57,6 +61,7 @@ class MockVectorStore(VectorStore):
|
||||
) -> "MockVectorStore":
|
||||
return cls()
|
||||
|
||||
@override
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
|
@ -38,7 +38,7 @@ repo_dict = {
|
||||
}
|
||||
|
||||
|
||||
def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate:
|
||||
def repo_lookup(owner_repo_commit: str, **_: Any) -> ChatPromptTemplate:
|
||||
return repo_dict[owner_repo_commit]
|
||||
|
||||
|
||||
|
@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from pytest_mock import MockerFixture
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
from typing_extensions import override
|
||||
|
||||
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
|
||||
|
||||
@ -15,6 +16,7 @@ class FakeChatOpenAI(BaseChatModel):
|
||||
def _llm_type(self) -> str:
|
||||
return "fake-openai-chat-model"
|
||||
|
||||
@override
|
||||
def _generate(
|
||||
self,
|
||||
messages: list[BaseMessage],
|
||||
|
@ -3,16 +3,14 @@
|
||||
import uuid
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langsmith.client import Client
|
||||
from langsmith.schemas import Dataset, Example
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.transform import TransformChain
|
||||
from langchain.smith.evaluation.runner_utils import (
|
||||
InputFormatError,
|
||||
@ -243,7 +241,7 @@ def test_run_chat_model_all_formats(inputs: dict[str, Any]) -> None:
|
||||
|
||||
|
||||
@freeze_time("2023-01-01")
|
||||
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
async def test_arun_on_dataset() -> None:
|
||||
dataset = Dataset(
|
||||
id=uuid.uuid4(),
|
||||
name="test",
|
||||
@ -298,22 +296,20 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
),
|
||||
]
|
||||
|
||||
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
||||
def mock_read_dataset(*_: Any, **__: Any) -> Dataset:
|
||||
return dataset
|
||||
|
||||
def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
|
||||
def mock_list_examples(*_: Any, **__: Any) -> Iterator[Example]:
|
||||
return iter(examples)
|
||||
|
||||
async def mock_arun_chain(
|
||||
example: Example,
|
||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
||||
tags: Optional[list[str]] = None,
|
||||
callbacks: Optional[Any] = None,
|
||||
**kwargs: Any,
|
||||
*_: Any,
|
||||
**__: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {"result": f"Result for example {example.id}"}
|
||||
|
||||
def mock_create_project(*args: Any, **kwargs: Any) -> Any:
|
||||
def mock_create_project(*_: Any, **__: Any) -> Any:
|
||||
proj = mock.MagicMock()
|
||||
proj.id = "123"
|
||||
return proj
|
||||
|
@ -8,13 +8,13 @@ from langchain.tools.render import (
|
||||
|
||||
|
||||
@tool
|
||||
def search(query: str) -> str:
|
||||
def search(query: str) -> str: # noqa: ARG001
|
||||
"""Lookup things online."""
|
||||
return "foo"
|
||||
|
||||
|
||||
@tool
|
||||
def calculator(expression: str) -> str:
|
||||
def calculator(expression: str) -> str: # noqa: ARG001
|
||||
"""Do math."""
|
||||
return "bar"
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user