mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 05:20:39 +00:00
chore(langchain): add ruff rules ARG (#32110)
See https://docs.astral.sh/ruff/rules/#flake8-unused-arguments-arg Co-authored-by: Mason Daugherty <mason@langchain.dev>
This commit is contained in:
parent
a2ad5aca41
commit
efdfa00d10
@ -116,8 +116,8 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
def return_stopped_response(
|
def return_stopped_response(
|
||||||
self,
|
self,
|
||||||
early_stopping_method: str,
|
early_stopping_method: str,
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
|
||||||
**kwargs: Any,
|
**_: Any,
|
||||||
) -> AgentFinish:
|
) -> AgentFinish:
|
||||||
"""Return response when agent has been stopped due to max iterations.
|
"""Return response when agent has been stopped due to max iterations.
|
||||||
|
|
||||||
@ -125,7 +125,6 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
early_stopping_method: Method to use for early stopping.
|
early_stopping_method: Method to use for early stopping.
|
||||||
intermediate_steps: Steps the LLM has taken to date,
|
intermediate_steps: Steps the LLM has taken to date,
|
||||||
along with observations.
|
along with observations.
|
||||||
**kwargs: User inputs.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AgentFinish: Agent finish object.
|
AgentFinish: Agent finish object.
|
||||||
@ -168,6 +167,7 @@ class BaseSingleActionAgent(BaseModel):
|
|||||||
"""Return Identifier of an agent type."""
|
"""Return Identifier of an agent type."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||||
"""Return dictionary representation of agent.
|
"""Return dictionary representation of agent.
|
||||||
|
|
||||||
@ -289,8 +289,8 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
def return_stopped_response(
|
def return_stopped_response(
|
||||||
self,
|
self,
|
||||||
early_stopping_method: str,
|
early_stopping_method: str,
|
||||||
intermediate_steps: list[tuple[AgentAction, str]],
|
intermediate_steps: list[tuple[AgentAction, str]], # noqa: ARG002
|
||||||
**kwargs: Any,
|
**_: Any,
|
||||||
) -> AgentFinish:
|
) -> AgentFinish:
|
||||||
"""Return response when agent has been stopped due to max iterations.
|
"""Return response when agent has been stopped due to max iterations.
|
||||||
|
|
||||||
@ -298,7 +298,6 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
early_stopping_method: Method to use for early stopping.
|
early_stopping_method: Method to use for early stopping.
|
||||||
intermediate_steps: Steps the LLM has taken to date,
|
intermediate_steps: Steps the LLM has taken to date,
|
||||||
along with observations.
|
along with observations.
|
||||||
**kwargs: User inputs.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AgentFinish: Agent finish object.
|
AgentFinish: Agent finish object.
|
||||||
@ -317,6 +316,7 @@ class BaseMultiActionAgent(BaseModel):
|
|||||||
"""Return Identifier of an agent type."""
|
"""Return Identifier of an agent type."""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().model_dump()
|
_dict = super().model_dump()
|
||||||
@ -651,6 +651,7 @@ class LLMSingleActionAgent(BaseSingleActionAgent):
|
|||||||
"""
|
"""
|
||||||
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
return list(set(self.llm_chain.input_keys) - {"intermediate_steps"})
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().dict()
|
_dict = super().dict()
|
||||||
@ -735,6 +736,7 @@ class Agent(BaseSingleActionAgent):
|
|||||||
allowed_tools: Optional[list[str]] = None
|
allowed_tools: Optional[list[str]] = None
|
||||||
"""Allowed tools for the agent. If None, all tools are allowed."""
|
"""Allowed tools for the agent. If None, all tools are allowed."""
|
||||||
|
|
||||||
|
@override
|
||||||
def dict(self, **kwargs: Any) -> builtins.dict:
|
def dict(self, **kwargs: Any) -> builtins.dict:
|
||||||
"""Return dictionary representation of agent."""
|
"""Return dictionary representation of agent."""
|
||||||
_dict = super().dict()
|
_dict = super().dict()
|
||||||
@ -750,18 +752,6 @@ class Agent(BaseSingleActionAgent):
|
|||||||
"""Return values of the agent."""
|
"""Return values of the agent."""
|
||||||
return ["output"]
|
return ["output"]
|
||||||
|
|
||||||
def _fix_text(self, text: str) -> str:
|
|
||||||
"""Fix the text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Text to fix.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Fixed text.
|
|
||||||
"""
|
|
||||||
msg = "fix_text not implemented for this agent."
|
|
||||||
raise ValueError(msg)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _stop(self) -> list[str]:
|
def _stop(self) -> list[str]:
|
||||||
return [
|
return [
|
||||||
@ -1021,6 +1011,7 @@ class ExceptionTool(BaseTool):
|
|||||||
description: str = "Exception tool"
|
description: str = "Exception tool"
|
||||||
"""Description of the tool."""
|
"""Description of the tool."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -1028,6 +1019,7 @@ class ExceptionTool(BaseTool):
|
|||||||
) -> str:
|
) -> str:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
@override
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -1188,6 +1180,7 @@ class AgentExecutor(Chain):
|
|||||||
return cast("RunnableAgentType", self.agent)
|
return cast("RunnableAgentType", self.agent)
|
||||||
return self.agent
|
return self.agent
|
||||||
|
|
||||||
|
@override
|
||||||
def save(self, file_path: Union[Path, str]) -> None:
|
def save(self, file_path: Union[Path, str]) -> None:
|
||||||
"""Raise error - saving not supported for Agent Executors.
|
"""Raise error - saving not supported for Agent Executors.
|
||||||
|
|
||||||
@ -1218,7 +1211,7 @@ class AgentExecutor(Chain):
|
|||||||
callbacks: Callbacks = None,
|
callbacks: Callbacks = None,
|
||||||
*,
|
*,
|
||||||
include_run_info: bool = False,
|
include_run_info: bool = False,
|
||||||
async_: bool = False, # arg kept for backwards compat, but ignored
|
async_: bool = False, # noqa: ARG002 arg kept for backwards compat, but ignored
|
||||||
) -> AgentExecutorIterator:
|
) -> AgentExecutorIterator:
|
||||||
"""Enables iteration over steps taken to reach final output.
|
"""Enables iteration over steps taken to reach final output.
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ from langchain_core.prompts.chat import (
|
|||||||
)
|
)
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||||
from langchain.agents.agent import Agent, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentOutputParser
|
||||||
@ -65,6 +66,7 @@ class ChatAgent(Agent):
|
|||||||
return agent_scratchpad
|
return agent_scratchpad
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return ChatOutputParser()
|
return ChatOutputParser()
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||||
from langchain.agents.agent import Agent, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentOutputParser
|
||||||
@ -35,6 +36,7 @@ class ConversationalAgent(Agent):
|
|||||||
"""Output parser for the agent."""
|
"""Output parser for the agent."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(
|
def _get_default_output_parser(
|
||||||
cls,
|
cls,
|
||||||
ai_prefix: str = "AI",
|
ai_prefix: str = "AI",
|
||||||
|
@ -20,6 +20,7 @@ from langchain_core.prompts.chat import (
|
|||||||
)
|
)
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentOutputParser
|
||||||
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
|
from langchain.agents.conversational_chat.output_parser import ConvoOutputParser
|
||||||
@ -42,6 +43,7 @@ class ConversationalChatAgent(Agent):
|
|||||||
"""Template for the tool response."""
|
"""Template for the tool response."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return ConvoOutputParser()
|
return ConvoOutputParser()
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from langchain_core.prompts import PromptTemplate
|
|||||||
from langchain_core.tools import BaseTool, Tool
|
from langchain_core.tools import BaseTool, Tool
|
||||||
from langchain_core.tools.render import render_text_description
|
from langchain_core.tools.render import render_text_description
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||||
@ -51,6 +52,7 @@ class ZeroShotAgent(Agent):
|
|||||||
output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)
|
output_parser: AgentOutputParser = Field(default_factory=MRKLOutputParser)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return MRKLOutputParser()
|
return MRKLOutputParser()
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.agents.format_scratchpad import (
|
from langchain.agents.format_scratchpad import (
|
||||||
format_to_openai_function_messages,
|
format_to_openai_function_messages,
|
||||||
@ -55,6 +56,7 @@ class AgentTokenBufferMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer.
|
"""Return history buffer.
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ from langchain_core.language_models import BaseLanguageModel
|
|||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.tools import BaseTool, Tool
|
from langchain_core.tools import BaseTool, Tool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
from langchain._api.deprecation import AGENT_DEPRECATION_WARNING
|
||||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||||
@ -38,6 +39,7 @@ class ReActDocstoreAgent(Agent):
|
|||||||
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
|
output_parser: AgentOutputParser = Field(default_factory=ReActOutputParser)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return ReActOutputParser()
|
return ReActOutputParser()
|
||||||
|
|
||||||
@ -47,6 +49,7 @@ class ReActDocstoreAgent(Agent):
|
|||||||
return AgentType.REACT_DOCSTORE
|
return AgentType.REACT_DOCSTORE
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||||
"""Return default prompt."""
|
"""Return default prompt."""
|
||||||
return WIKI_PROMPT
|
return WIKI_PROMPT
|
||||||
@ -141,6 +144,7 @@ class ReActTextWorldAgent(ReActDocstoreAgent):
|
|||||||
"""Agent for the ReAct TextWorld chain."""
|
"""Agent for the ReAct TextWorld chain."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||||
"""Return default prompt."""
|
"""Return default prompt."""
|
||||||
return TEXTWORLD_PROMPT
|
return TEXTWORLD_PROMPT
|
||||||
|
@ -11,6 +11,7 @@ from langchain_core.prompts import BasePromptTemplate
|
|||||||
from langchain_core.runnables import Runnable, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnablePassthrough
|
||||||
from langchain_core.tools import BaseTool, Tool
|
from langchain_core.tools import BaseTool, Tool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
from langchain.agents.agent import Agent, AgentExecutor, AgentOutputParser
|
||||||
from langchain.agents.agent_types import AgentType
|
from langchain.agents.agent_types import AgentType
|
||||||
@ -32,6 +33,7 @@ class SelfAskWithSearchAgent(Agent):
|
|||||||
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
|
output_parser: AgentOutputParser = Field(default_factory=SelfAskOutputParser)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
def _get_default_output_parser(cls, **kwargs: Any) -> AgentOutputParser:
|
||||||
return SelfAskOutputParser()
|
return SelfAskOutputParser()
|
||||||
|
|
||||||
@ -41,6 +43,7 @@ class SelfAskWithSearchAgent(Agent):
|
|||||||
return AgentType.SELF_ASK_WITH_SEARCH
|
return AgentType.SELF_ASK_WITH_SEARCH
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
def create_prompt(cls, tools: Sequence[BaseTool]) -> BasePromptTemplate:
|
||||||
"""Prompt does not depend on tools."""
|
"""Prompt does not depend on tools."""
|
||||||
return PROMPT
|
return PROMPT
|
||||||
|
@ -71,6 +71,7 @@ class StructuredChatAgent(Agent):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def _get_default_output_parser(
|
def _get_default_output_parser(
|
||||||
cls,
|
cls,
|
||||||
llm: Optional[BaseLanguageModel] = None,
|
llm: Optional[BaseLanguageModel] = None,
|
||||||
|
@ -7,6 +7,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForToolRun,
|
CallbackManagerForToolRun,
|
||||||
)
|
)
|
||||||
from langchain_core.tools import BaseTool, tool
|
from langchain_core.tools import BaseTool, tool
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class InvalidTool(BaseTool):
|
class InvalidTool(BaseTool):
|
||||||
@ -17,6 +18,7 @@ class InvalidTool(BaseTool):
|
|||||||
description: str = "Called when tool name is invalid. Suggests valid tool names."
|
description: str = "Called when tool name is invalid. Suggests valid tool names."
|
||||||
"""Description of the tool."""
|
"""Description of the tool."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _run(
|
def _run(
|
||||||
self,
|
self,
|
||||||
requested_tool_name: str,
|
requested_tool_name: str,
|
||||||
@ -30,6 +32,7 @@ class InvalidTool(BaseTool):
|
|||||||
f"try one of [{available_tool_names_str}]."
|
f"try one of [{available_tool_names_str}]."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _arun(
|
async def _arun(
|
||||||
self,
|
self,
|
||||||
requested_tool_name: str,
|
requested_tool_name: str,
|
||||||
|
@ -4,6 +4,7 @@ import sys
|
|||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import StreamingStdOutCallbackHandler
|
from langchain_core.callbacks import StreamingStdOutCallbackHandler
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"]
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
self.stream_prefix = stream_prefix
|
self.stream_prefix = stream_prefix
|
||||||
self.answer_reached = False
|
self.answer_reached = False
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -72,6 +74,7 @@ class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|||||||
"""Run when LLM starts running."""
|
"""Run when LLM starts running."""
|
||||||
self.answer_reached = False
|
self.answer_reached = False
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||||
"""Run on new LLM token. Only available when streaming is enabled."""
|
"""Run on new LLM token. Only available when streaming is enabled."""
|
||||||
|
|
||||||
|
@ -388,7 +388,7 @@ except ImportError:
|
|||||||
class APIChain: # type: ignore[no-redef]
|
class APIChain: # type: ignore[no-redef]
|
||||||
"""Raise an ImportError if APIChain is used without langchain_community."""
|
"""Raise an ImportError if APIChain is used without langchain_community."""
|
||||||
|
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *_: Any, **__: Any) -> None:
|
||||||
"""Raise an ImportError if APIChain is used without langchain_community."""
|
"""Raise an ImportError if APIChain is used without langchain_community."""
|
||||||
msg = (
|
msg = (
|
||||||
"To use the APIChain, you must install the langchain_community package."
|
"To use the APIChain, you must install the langchain_community package."
|
||||||
|
@ -83,7 +83,7 @@ class BaseCombineDocumentsChain(Chain, ABC):
|
|||||||
"""
|
"""
|
||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]:
|
def prompt_length(self, docs: list[Document], **kwargs: Any) -> Optional[int]: # noqa: ARG002
|
||||||
"""Return the prompt length given the documents passed in.
|
"""Return the prompt length given the documents passed in.
|
||||||
|
|
||||||
This can be used by a caller to determine whether passing in a list
|
This can be used by a caller to determine whether passing in a list
|
||||||
|
@ -402,6 +402,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
|||||||
|
|
||||||
return docs[:num_docs]
|
return docs[:num_docs]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_docs(
|
def _get_docs(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
@ -416,6 +417,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
|
|||||||
)
|
)
|
||||||
return self._reduce_tokens_below_limit(docs)
|
return self._reduce_tokens_below_limit(docs)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_docs(
|
async def _aget_docs(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
@ -512,6 +514,7 @@ class ChatVectorDBChain(BaseConversationalRetrievalChain):
|
|||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_docs(
|
def _get_docs(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
def __getattr__(name: str = "") -> None:
|
def __getattr__(_: str = "") -> None:
|
||||||
"""Raise an error on import since is deprecated."""
|
"""Raise an error on import since is deprecated."""
|
||||||
msg = (
|
msg = (
|
||||||
"This module has been moved to langchain-experimental. "
|
"This module has been moved to langchain-experimental. "
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
def __getattr__(name: str = "") -> None:
|
def __getattr__(_: str = "") -> None:
|
||||||
"""Raise an error on import since is deprecated."""
|
"""Raise an error on import since is deprecated."""
|
||||||
msg = (
|
msg = (
|
||||||
"This module has been moved to langchain-experimental. "
|
"This module has been moved to langchain-experimental. "
|
||||||
|
@ -39,7 +39,7 @@ try:
|
|||||||
from langchain_community.llms.loading import load_llm, load_llm_from_config
|
from langchain_community.llms.loading import load_llm, load_llm_from_config
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
def load_llm(*args: Any, **kwargs: Any) -> None:
|
def load_llm(*_: Any, **__: Any) -> None:
|
||||||
"""Import error for load_llm."""
|
"""Import error for load_llm."""
|
||||||
msg = (
|
msg = (
|
||||||
"To use this load_llm functionality you must install the "
|
"To use this load_llm functionality you must install the "
|
||||||
@ -48,7 +48,7 @@ except ImportError:
|
|||||||
)
|
)
|
||||||
raise ImportError(msg)
|
raise ImportError(msg)
|
||||||
|
|
||||||
def load_llm_from_config(*args: Any, **kwargs: Any) -> None:
|
def load_llm_from_config(*_: Any, **__: Any) -> None:
|
||||||
"""Import error for load_llm_from_config."""
|
"""Import error for load_llm_from_config."""
|
||||||
msg = (
|
msg = (
|
||||||
"To use this load_llm_from_config functionality you must install the "
|
"To use this load_llm_from_config functionality you must install the "
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks import (
|
|||||||
)
|
)
|
||||||
from langchain_core.utils import check_package_version, get_from_dict_or_env
|
from langchain_core.utils import check_package_version, get_from_dict_or_env
|
||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
@ -105,6 +106,7 @@ class OpenAIModerationChain(Chain):
|
|||||||
return error_str
|
return error_str
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
|
@ -16,6 +16,7 @@ from langchain_core.documents import Document
|
|||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from pydantic import ConfigDict, model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains import ReduceDocumentsChain
|
from langchain.chains import ReduceDocumentsChain
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -240,6 +241,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
|||||||
"""
|
"""
|
||||||
return [self.input_docs_key, self.question_key]
|
return [self.input_docs_key, self.question_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_docs(
|
def _get_docs(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -249,6 +251,7 @@ class QAWithSourcesChain(BaseQAWithSourcesChain):
|
|||||||
"""Get docs to run questioning over."""
|
"""Get docs to run questioning over."""
|
||||||
return inputs.pop(self.input_docs_key)
|
return inputs.pop(self.input_docs_key)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_docs(
|
async def _aget_docs(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
||||||
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
from langchain.chains.qa_with_sources.base import BaseQAWithSourcesChain
|
||||||
@ -48,6 +49,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
|
|||||||
|
|
||||||
return docs[:num_docs]
|
return docs[:num_docs]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_docs(
|
def _get_docs(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
|
@ -11,7 +11,7 @@ try:
|
|||||||
from lark import Lark, Transformer, v_args
|
from lark import Lark, Transformer, v_args
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|
||||||
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore[misc]
|
def v_args(*_: Any, **__: Any) -> Any: # type: ignore[misc]
|
||||||
"""Dummy decorator for when lark is not installed."""
|
"""Dummy decorator for when lark is not installed."""
|
||||||
return lambda _: None
|
return lambda _: None
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@ from langchain_core.prompts import PromptTemplate
|
|||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic import ConfigDict, Field, model_validator
|
from pydantic import ConfigDict, Field, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
|
||||||
@ -330,6 +331,7 @@ class VectorDBQA(BaseRetrievalQA):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_docs(
|
def _get_docs(
|
||||||
self,
|
self,
|
||||||
question: str,
|
question: str,
|
||||||
|
@ -11,6 +11,7 @@ from langchain_core.documents import Document
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.router.base import RouterChain
|
from langchain.chains.router.base import RouterChain
|
||||||
|
|
||||||
@ -34,6 +35,7 @@ class EmbeddingRouterChain(RouterChain):
|
|||||||
"""
|
"""
|
||||||
return self.routing_keys
|
return self.routing_keys
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -43,6 +45,7 @@ class EmbeddingRouterChain(RouterChain):
|
|||||||
results = self.vectorstore.similarity_search(_input, k=1)
|
results = self.vectorstore.similarity_search(_input, k=1)
|
||||||
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
return {"next_inputs": inputs, "destination": results[0].metadata["name"]}
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
|
||||||
@ -63,6 +64,7 @@ class TransformChain(Chain):
|
|||||||
"""
|
"""
|
||||||
return self.output_variables
|
return self.output_variables
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
@ -70,6 +72,7 @@ class TransformChain(Chain):
|
|||||||
) -> dict[str, str]:
|
) -> dict[str, str]:
|
||||||
return self.transform_cb(inputs)
|
return self.transform_cb(inputs)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
|
@ -331,6 +331,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|||||||
"""
|
"""
|
||||||
return ["prediction", "reference"]
|
return ["prediction", "reference"]
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -355,6 +356,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|||||||
score = self._compute_score(vectors)
|
score = self._compute_score(vectors)
|
||||||
return {"score": score}
|
return {"score": score}
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -382,6 +384,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|||||||
score = self._compute_score(vectors)
|
score = self._compute_score(vectors)
|
||||||
return {"score": score}
|
return {"score": score}
|
||||||
|
|
||||||
|
@override
|
||||||
def _evaluate_strings(
|
def _evaluate_strings(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -416,6 +419,7 @@ class EmbeddingDistanceEvalChain(_EmbeddingDistanceChainMixin, StringEvaluator):
|
|||||||
)
|
)
|
||||||
return self._prepare_output(result)
|
return self._prepare_output(result)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aevaluate_strings(
|
async def _aevaluate_strings(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -478,6 +482,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|||||||
"""Return the evaluation name."""
|
"""Return the evaluation name."""
|
||||||
return f"pairwise_embedding_{self.distance_metric.value}_distance"
|
return f"pairwise_embedding_{self.distance_metric.value}_distance"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -505,6 +510,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|||||||
score = self._compute_score(vectors)
|
score = self._compute_score(vectors)
|
||||||
return {"score": score}
|
return {"score": score}
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -532,6 +538,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|||||||
score = self._compute_score(vectors)
|
score = self._compute_score(vectors)
|
||||||
return {"score": score}
|
return {"score": score}
|
||||||
|
|
||||||
|
@override
|
||||||
def _evaluate_string_pairs(
|
def _evaluate_string_pairs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -567,6 +574,7 @@ class PairwiseEmbeddingDistanceEvalChain(
|
|||||||
)
|
)
|
||||||
return self._prepare_output(result)
|
return self._prepare_output(result)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aevaluate_string_pairs(
|
async def _aevaluate_string_pairs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import string
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.evaluation.schema import StringEvaluator
|
from langchain.evaluation.schema import StringEvaluator
|
||||||
|
|
||||||
|
|
||||||
@ -78,6 +80,7 @@ class ExactMatchStringEvaluator(StringEvaluator):
|
|||||||
"""
|
"""
|
||||||
return "exact_match"
|
return "exact_match"
|
||||||
|
|
||||||
|
@override
|
||||||
def _evaluate_strings( # type: ignore[override]
|
def _evaluate_strings( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -33,12 +33,9 @@ class JsonSchemaEvaluator(StringEvaluator):
|
|||||||
|
|
||||||
""" # noqa: E501
|
""" # noqa: E501
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **_: Any) -> None:
|
||||||
"""Initializes the JsonSchemaEvaluator.
|
"""Initializes the JsonSchemaEvaluator.
|
||||||
|
|
||||||
Args:
|
|
||||||
kwargs: Additional keyword arguments.
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ImportError: If the jsonschema package is not installed.
|
ImportError: If the jsonschema package is not installed.
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.evaluation.schema import StringEvaluator
|
from langchain.evaluation.schema import StringEvaluator
|
||||||
|
|
||||||
|
|
||||||
@ -70,6 +72,7 @@ class RegexMatchStringEvaluator(StringEvaluator):
|
|||||||
"""
|
"""
|
||||||
return "regex_match"
|
return "regex_match"
|
||||||
|
|
||||||
|
@override
|
||||||
def _evaluate_strings( # type: ignore[override]
|
def _evaluate_strings( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -224,6 +224,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
|||||||
"""
|
"""
|
||||||
return f"{self.distance.value}_distance"
|
return f"{self.distance.value}_distance"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -242,6 +243,7 @@ class StringDistanceEvalChain(StringEvaluator, _RapidFuzzChainMixin):
|
|||||||
"""
|
"""
|
||||||
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
|
return {"score": self.compute_metric(inputs["reference"], inputs["prediction"])}
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -357,6 +359,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
|||||||
"""
|
"""
|
||||||
return f"pairwise_{self.distance.value}_distance"
|
return f"pairwise_{self.distance.value}_distance"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -377,6 +380,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
|||||||
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, Any],
|
inputs: dict[str, Any],
|
||||||
@ -397,6 +401,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
|||||||
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
"score": self.compute_metric(inputs["prediction"], inputs["prediction_b"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
def _evaluate_string_pairs(
|
def _evaluate_string_pairs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@ -431,6 +436,7 @@ class PairwiseStringDistanceEvalChain(PairwiseStringEvaluator, _RapidFuzzChainMi
|
|||||||
)
|
)
|
||||||
return self._prepare_output(result)
|
return self._prepare_output(result)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aevaluate_string_pairs(
|
async def _aevaluate_string_pairs(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
|
@ -79,10 +79,12 @@ class ConversationBufferMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
|
||||||
|
@override
|
||||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return key-value pairs given the text input to the chain."""
|
"""Return key-value pairs given the text input to the chain."""
|
||||||
buffer = await self.abuffer()
|
buffer = await self.abuffer()
|
||||||
@ -133,6 +135,7 @@ class ConversationStringBufferMemory(BaseMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, str]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
@ -2,6 +2,7 @@ from typing import Any, Union
|
|||||||
|
|
||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
|
||||||
@ -55,6 +56,7 @@ class ConversationBufferWindowMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
@ -9,6 +9,7 @@ from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_strin
|
|||||||
from langchain_core.prompts import BasePromptTemplate
|
from langchain_core.prompts import BasePromptTemplate
|
||||||
from langchain_core.utils import pre_init
|
from langchain_core.utils import pre_init
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
@ -133,6 +134,7 @@ class ConversationSummaryMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
if self.return_messages:
|
if self.return_messages:
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any, Union
|
|||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
from langchain_core.utils import pre_init
|
from langchain_core.utils import pre_init
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
from langchain.memory.summary import SummarizerMixin
|
from langchain.memory.summary import SummarizerMixin
|
||||||
@ -46,6 +47,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
buffer = self.chat_memory.messages
|
buffer = self.chat_memory.messages
|
||||||
@ -64,6 +66,7 @@ class ConversationSummaryBufferMemory(BaseChatMemory, SummarizerMixin):
|
|||||||
)
|
)
|
||||||
return {self.memory_key: final_buffer}
|
return {self.memory_key: final_buffer}
|
||||||
|
|
||||||
|
@override
|
||||||
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
async def aload_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Asynchronously return key-value pairs given the text input to the chain."""
|
"""Asynchronously return key-value pairs given the text input to the chain."""
|
||||||
buffer = await self.chat_memory.aget_messages()
|
buffer = await self.chat_memory.aget_messages()
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any
|
|||||||
from langchain_core._api import deprecated
|
from langchain_core._api import deprecated
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
from langchain_core.language_models import BaseLanguageModel
|
||||||
from langchain_core.messages import BaseMessage, get_buffer_string
|
from langchain_core.messages import BaseMessage, get_buffer_string
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.memory.chat_memory import BaseChatMemory
|
from langchain.memory.chat_memory import BaseChatMemory
|
||||||
|
|
||||||
@ -55,6 +56,7 @@ class ConversationTokenBufferMemory(BaseChatMemory):
|
|||||||
"""
|
"""
|
||||||
return [self.memory_key]
|
return [self.memory_key]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
def load_memory_variables(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Return history buffer."""
|
"""Return history buffer."""
|
||||||
return {self.memory_key: self.buffer}
|
return {self.memory_key: self.buffer}
|
||||||
|
@ -110,7 +110,7 @@ class LLMChainExtractor(BaseDocumentCompressor):
|
|||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
prompt: Optional[PromptTemplate] = None,
|
prompt: Optional[PromptTemplate] = None,
|
||||||
get_input: Optional[Callable[[str, Document], str]] = None,
|
get_input: Optional[Callable[[str, Document], str]] = None,
|
||||||
llm_chain_kwargs: Optional[dict] = None,
|
llm_chain_kwargs: Optional[dict] = None, # noqa: ARG003
|
||||||
) -> LLMChainExtractor:
|
) -> LLMChainExtractor:
|
||||||
"""Initialize from LLM."""
|
"""Initialize from LLM."""
|
||||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||||
|
@ -9,6 +9,7 @@ from langchain_core.callbacks import Callbacks
|
|||||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||||
from langchain_core.utils import get_from_dict_or_env
|
from langchain_core.utils import get_from_dict_or_env
|
||||||
from pydantic import ConfigDict, model_validator
|
from pydantic import ConfigDict, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
@ -98,6 +99,7 @@ class CohereRerank(BaseDocumentCompressor):
|
|||||||
for res in results
|
for res in results
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@override
|
||||||
def compress_documents(
|
def compress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
|
@ -7,6 +7,7 @@ from typing import Optional
|
|||||||
from langchain_core.callbacks import Callbacks
|
from langchain_core.callbacks import Callbacks
|
||||||
from langchain_core.documents import BaseDocumentCompressor, Document
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||||||
from pydantic import ConfigDict
|
from pydantic import ConfigDict
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
|
from langchain.retrievers.document_compressors.cross_encoder import BaseCrossEncoder
|
||||||
|
|
||||||
@ -25,6 +26,7 @@ class CrossEncoderReranker(BaseDocumentCompressor):
|
|||||||
extra="forbid",
|
extra="forbid",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
def compress_documents(
|
def compress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
|
@ -6,6 +6,7 @@ from langchain_core.documents import BaseDocumentCompressor, Document
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.utils import pre_init
|
from langchain_core.utils import pre_init
|
||||||
from pydantic import ConfigDict, Field
|
from pydantic import ConfigDict, Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def _get_similarity_function() -> Callable:
|
def _get_similarity_function() -> Callable:
|
||||||
@ -50,6 +51,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
|||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def compress_documents(
|
def compress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
@ -93,6 +95,7 @@ class EmbeddingsFilter(BaseDocumentCompressor):
|
|||||||
stateful_documents[i].state["query_similarity_score"] = similarity[i]
|
stateful_documents[i].state["query_similarity_score"] = similarity[i]
|
||||||
return [stateful_documents[i] for i in included_idxs]
|
return [stateful_documents[i] for i in included_idxs]
|
||||||
|
|
||||||
|
@override
|
||||||
async def acompress_documents(
|
async def acompress_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
|
@ -67,7 +67,7 @@ class MultiQueryRetriever(BaseRetriever):
|
|||||||
retriever: BaseRetriever,
|
retriever: BaseRetriever,
|
||||||
llm: BaseLanguageModel,
|
llm: BaseLanguageModel,
|
||||||
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
|
prompt: BasePromptTemplate = DEFAULT_QUERY_PROMPT,
|
||||||
parser_key: Optional[str] = None,
|
parser_key: Optional[str] = None, # noqa: ARG003
|
||||||
include_original: bool = False, # noqa: FBT001,FBT002
|
include_original: bool = False, # noqa: FBT001,FBT002
|
||||||
) -> "MultiQueryRetriever":
|
) -> "MultiQueryRetriever":
|
||||||
"""Initialize from llm using default template.
|
"""Initialize from llm using default template.
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.retrievers import BaseRetriever
|
|||||||
from langchain_core.stores import BaseStore, ByteStore
|
from langchain_core.stores import BaseStore, ByteStore
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic import Field, model_validator
|
from pydantic import Field, model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.storage._lc_store import create_kv_docstore
|
from langchain.storage._lc_store import create_kv_docstore
|
||||||
|
|
||||||
@ -54,6 +55,7 @@ class MultiVectorRetriever(BaseRetriever):
|
|||||||
values["docstore"] = docstore
|
values["docstore"] = docstore
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -91,6 +93,7 @@ class MultiVectorRetriever(BaseRetriever):
|
|||||||
docs = self.docstore.mget(ids)
|
docs = self.docstore.mget(ids)
|
||||||
return [d for d in docs if d is not None]
|
return [d for d in docs if d is not None]
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.documents import Document
|
|||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
from pydantic import ConfigDict, Field
|
from pydantic import ConfigDict, Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
|
def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> float:
|
||||||
@ -128,6 +129,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
|||||||
result.append(buffered_doc)
|
result.append(buffered_doc)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -142,6 +144,7 @@ class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
|||||||
docs_and_scores.update(self.get_salient_docs(query))
|
docs_and_scores.update(self.get_salient_docs(query))
|
||||||
return self._get_rescored_docs(docs_and_scores)
|
return self._get_rescored_docs(docs_and_scores)
|
||||||
|
|
||||||
|
@override
|
||||||
async def _aget_relevant_documents(
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -19,7 +19,6 @@ class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
|
|||||||
total: int,
|
total: int,
|
||||||
ncols: int = 50,
|
ncols: int = 50,
|
||||||
end_with: str = "\n",
|
end_with: str = "\n",
|
||||||
**kwargs: Any,
|
|
||||||
):
|
):
|
||||||
"""Initialize the progress bar.
|
"""Initialize the progress bar.
|
||||||
|
|
||||||
|
@ -355,6 +355,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
|||||||
feedback.evaluator_info[RUN_KEY] = output[RUN_KEY]
|
feedback.evaluator_info[RUN_KEY] = output[RUN_KEY]
|
||||||
return feedback
|
return feedback
|
||||||
|
|
||||||
|
@override
|
||||||
def evaluate_run(
|
def evaluate_run(
|
||||||
self,
|
self,
|
||||||
run: Run,
|
run: Run,
|
||||||
@ -372,6 +373,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
|||||||
# TODO: Add run ID once we can declare it via callbacks
|
# TODO: Add run ID once we can declare it via callbacks
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@override
|
||||||
async def aevaluate_run(
|
async def aevaluate_run(
|
||||||
self,
|
self,
|
||||||
run: Run,
|
run: Run,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def __getattr__(name: str = "") -> Any:
|
def __getattr__(_: str = "") -> Any:
|
||||||
msg = (
|
msg = (
|
||||||
"This tool has been moved to langchain experiment. "
|
"This tool has been moved to langchain experiment. "
|
||||||
"This tool has access to a python REPL. "
|
"This tool has access to a python REPL. "
|
||||||
|
@ -145,6 +145,7 @@ ignore-words-list = "momento,collison,ned,foor,reworkd,parth,whats,aapply,mysogy
|
|||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
select = [
|
select = [
|
||||||
"A", # flake8-builtins
|
"A", # flake8-builtins
|
||||||
|
"ARG", # flake8-unused-arguments
|
||||||
"ASYNC", # flake8-async
|
"ASYNC", # flake8-async
|
||||||
"B", # flake8-bugbear
|
"B", # flake8-bugbear
|
||||||
"C4", # flake8-comprehensions
|
"C4", # flake8-comprehensions
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
fake_texts = ["foo", "bar", "baz"]
|
fake_texts = ["foo", "bar", "baz"]
|
||||||
|
|
||||||
@ -18,6 +19,7 @@ class FakeEmbeddings(Embeddings):
|
|||||||
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
return self.embed_documents(texts)
|
return self.embed_documents(texts)
|
||||||
|
|
||||||
|
@override
|
||||||
def embed_query(self, text: str) -> list[float]:
|
def embed_query(self, text: str) -> list[float]:
|
||||||
"""Return constant query embeddings.
|
"""Return constant query embeddings.
|
||||||
Embeddings are identical to embed_documents(texts)[0].
|
Embeddings are identical to embed_documents(texts)[0].
|
||||||
|
@ -25,6 +25,7 @@ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|||||||
from langchain_core.runnables.utils import add
|
from langchain_core.runnables.utils import add
|
||||||
from langchain_core.tools import Tool, tool
|
from langchain_core.tools import Tool, tool
|
||||||
from langchain_core.tracers import RunLog, RunLogPatch
|
from langchain_core.tracers import RunLog, RunLogPatch
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.agents import (
|
from langchain.agents import (
|
||||||
AgentExecutor,
|
AgentExecutor,
|
||||||
@ -48,6 +49,7 @@ class FakeListLLM(LLM):
|
|||||||
responses: list[str]
|
responses: list[str]
|
||||||
i: int = -1
|
i: int = -1
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
@ -462,7 +464,7 @@ async def test_runnable_agent() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||||
"""A parser."""
|
"""A parser."""
|
||||||
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
|
return AgentFinish(return_values={"foo": "meow"}, log="hard-coded-message")
|
||||||
|
|
||||||
@ -569,7 +571,7 @@ async def test_runnable_agent_with_function_calls() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||||
"""A parser."""
|
"""A parser."""
|
||||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||||
|
|
||||||
@ -681,7 +683,7 @@ async def test_runnable_with_multi_action_per_step() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def fake_parse(inputs: dict) -> Union[AgentFinish, AgentAction]:
|
def fake_parse(_: dict) -> Union[AgentFinish, AgentAction]:
|
||||||
"""A parser."""
|
"""A parser."""
|
||||||
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
return cast("Union[AgentFinish, AgentAction]", next(parser_responses))
|
||||||
|
|
||||||
@ -1032,7 +1034,7 @@ async def test_openai_agent_tools_agent() -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
GenericFakeChatModel.bind_tools = lambda self, x: self # type: ignore[assignment,misc]
|
GenericFakeChatModel.bind_tools = lambda self, _: self # type: ignore[assignment,misc]
|
||||||
model = GenericFakeChatModel(messages=infinite_cycle)
|
model = GenericFakeChatModel(messages=infinite_cycle)
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.language_models.llms import LLM
|
|||||||
from langchain_core.messages import AIMessage, HumanMessage
|
from langchain_core.messages import AIMessage, HumanMessage
|
||||||
from langchain_core.runnables.utils import add
|
from langchain_core.runnables.utils import add
|
||||||
from langchain_core.tools import Tool
|
from langchain_core.tools import Tool
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
from langchain.agents import AgentExecutor, AgentType, initialize_agent
|
||||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||||
@ -19,6 +20,7 @@ class FakeListLLM(LLM):
|
|||||||
responses: list[str]
|
responses: list[str]
|
||||||
i: int = -1
|
i: int = -1
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -364,7 +364,7 @@ def test_agent_iterator_failing_tool() -> None:
|
|||||||
tools = [
|
tools = [
|
||||||
Tool(
|
Tool(
|
||||||
name="FailingTool",
|
name="FailingTool",
|
||||||
func=lambda x: 1 / 0, # This tool will raise a ZeroDivisionError
|
func=lambda _: 1 / 0, # This tool will raise a ZeroDivisionError
|
||||||
description="A tool that fails",
|
description="A tool that fails",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
@ -8,7 +8,7 @@ from tests.unit_tests.llms.fake_llm import FakeLLM
|
|||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def my_tool(query: str) -> str:
|
def my_tool(query: str) -> str: # noqa: ARG001
|
||||||
"""A fake tool."""
|
"""A fake tool."""
|
||||||
return "fake tool"
|
return "fake tool"
|
||||||
|
|
||||||
|
@ -141,11 +141,11 @@ def test_valid_action_and_answer_raises_exception() -> None:
|
|||||||
def test_from_chains() -> None:
|
def test_from_chains() -> None:
|
||||||
"""Test initializing from chains."""
|
"""Test initializing from chains."""
|
||||||
chain_configs = [
|
chain_configs = [
|
||||||
Tool(name="foo", func=lambda x: "foo", description="foobar1"),
|
Tool(name="foo", func=lambda _x: "foo", description="foobar1"),
|
||||||
Tool(name="bar", func=lambda x: "bar", description="foobar2"),
|
Tool(name="bar", func=lambda _x: "bar", description="foobar2"),
|
||||||
]
|
]
|
||||||
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
|
agent = ZeroShotAgent.from_llm_and_tools(FakeLLM(), chain_configs)
|
||||||
expected_tools_prompt = "foo(x) - foobar1\nbar(x) - foobar2"
|
expected_tools_prompt = "foo(_x) - foobar1\nbar(_x) - foobar2"
|
||||||
expected_tool_names = "foo, bar"
|
expected_tool_names = "foo, bar"
|
||||||
expected_template = "\n\n".join(
|
expected_template = "\n\n".join(
|
||||||
[
|
[
|
||||||
|
@ -7,7 +7,7 @@ import pytest
|
|||||||
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
|
from langchain.agents.openai_assistant import OpenAIAssistantRunnable
|
||||||
|
|
||||||
|
|
||||||
def _create_mock_client(*args: Any, use_async: bool = False, **kwargs: Any) -> Any:
|
def _create_mock_client(*_: Any, use_async: bool = False, **__: Any) -> Any:
|
||||||
client = AsyncMock() if use_async else MagicMock()
|
client = AsyncMock() if use_async else MagicMock()
|
||||||
mock_assistant = MagicMock()
|
mock_assistant = MagicMock()
|
||||||
mock_assistant.id = "abc123"
|
mock_assistant.id = "abc123"
|
||||||
|
@ -7,6 +7,7 @@ from uuid import UUID
|
|||||||
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class BaseFakeCallbackHandler(BaseModel):
|
class BaseFakeCallbackHandler(BaseModel):
|
||||||
@ -135,6 +136,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
"""Whether to ignore retriever callbacks."""
|
"""Whether to ignore retriever callbacks."""
|
||||||
return self.ignore_retriever_
|
return self.ignore_retriever_
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_start(
|
def on_llm_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -142,6 +144,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_start_common()
|
self.on_llm_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_new_token(
|
def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -149,6 +152,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_new_token_common()
|
self.on_llm_new_token_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_end(
|
def on_llm_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -156,6 +160,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_end_common()
|
self.on_llm_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_llm_error(
|
def on_llm_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -163,6 +168,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_llm_error_common()
|
self.on_llm_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retry(
|
def on_retry(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -170,6 +176,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retry_common()
|
self.on_retry_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -177,6 +184,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_start_common()
|
self.on_chain_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_end(
|
def on_chain_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -184,6 +192,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_end_common()
|
self.on_chain_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_chain_error(
|
def on_chain_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -191,6 +200,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_chain_error_common()
|
self.on_chain_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_start(
|
def on_tool_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -198,6 +208,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_start_common()
|
self.on_tool_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_end(
|
def on_tool_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -205,6 +216,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_end_common()
|
self.on_tool_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_tool_error(
|
def on_tool_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -212,6 +224,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_tool_error_common()
|
self.on_tool_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_action(
|
def on_agent_action(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -219,6 +232,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_agent_action_common()
|
self.on_agent_action_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_agent_finish(
|
def on_agent_finish(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -226,6 +240,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_agent_finish_common()
|
self.on_agent_finish_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_text(
|
def on_text(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -233,6 +248,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_text_common()
|
self.on_text_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_start(
|
def on_retriever_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -240,6 +256,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retriever_start_common()
|
self.on_retriever_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_end(
|
def on_retriever_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -247,6 +264,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retriever_end_common()
|
self.on_retriever_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
def on_retriever_error(
|
def on_retriever_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -259,6 +277,7 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
|
|||||||
|
|
||||||
|
|
||||||
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
class FakeCallbackHandlerWithChatStart(FakeCallbackHandler):
|
||||||
|
@override
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@ -290,6 +309,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
"""Whether to ignore agent callbacks."""
|
"""Whether to ignore agent callbacks."""
|
||||||
return self.ignore_agent_
|
return self.ignore_agent_
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_retry(
|
async def on_retry(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -297,6 +317,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> Any:
|
) -> Any:
|
||||||
self.on_retry_common()
|
self.on_retry_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_start(
|
async def on_llm_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -304,6 +325,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_start_common()
|
self.on_llm_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -311,6 +333,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_new_token_common()
|
self.on_llm_new_token_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_end(
|
async def on_llm_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -318,6 +341,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_end_common()
|
self.on_llm_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_error(
|
async def on_llm_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -325,6 +349,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_llm_error_common()
|
self.on_llm_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_start(
|
async def on_chain_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -332,6 +357,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_start_common()
|
self.on_chain_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_end(
|
async def on_chain_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -339,6 +365,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_end_common()
|
self.on_chain_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_chain_error(
|
async def on_chain_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -346,6 +373,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_chain_error_common()
|
self.on_chain_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_start(
|
async def on_tool_start(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -353,6 +381,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_start_common()
|
self.on_tool_start_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_end(
|
async def on_tool_end(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -360,6 +389,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_end_common()
|
self.on_tool_end_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_tool_error(
|
async def on_tool_error(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -367,6 +397,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_tool_error_common()
|
self.on_tool_error_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_agent_action(
|
async def on_agent_action(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -374,6 +405,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_agent_action_common()
|
self.on_agent_action_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_agent_finish(
|
async def on_agent_finish(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
@ -381,6 +413,7 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
|
|||||||
) -> None:
|
) -> None:
|
||||||
self.on_agent_finish_common()
|
self.on_agent_finish_common()
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_text(
|
async def on_text(
|
||||||
self,
|
self,
|
||||||
*args: Any,
|
*args: Any,
|
||||||
|
@ -3,6 +3,7 @@ import re
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.callbacks import FileCallbackHandler
|
from langchain.callbacks import FileCallbackHandler
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -25,6 +26,7 @@ class FakeChain(Chain):
|
|||||||
"""Output key of bar."""
|
"""Output key of bar."""
|
||||||
return self.the_output_keys
|
return self.the_output_keys
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
|
@ -2,6 +2,7 @@ from typing import Any, Optional
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.callbacks import CallbackManagerForChainRun
|
from langchain_core.callbacks import CallbackManagerForChainRun
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.callbacks import StdOutCallbackHandler
|
from langchain.callbacks import StdOutCallbackHandler
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
@ -24,6 +25,7 @@ class FakeChain(Chain):
|
|||||||
"""Output key of bar."""
|
"""Output key of bar."""
|
||||||
return self.the_output_keys
|
return self.the_output_keys
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
|
@ -7,6 +7,7 @@ import pytest
|
|||||||
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
from langchain_core.callbacks.manager import CallbackManagerForChainRun
|
||||||
from langchain_core.memory import BaseMemory
|
from langchain_core.memory import BaseMemory
|
||||||
from langchain_core.tracers.context import collect_runs
|
from langchain_core.tracers.context import collect_runs
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.schema import RUN_KEY
|
from langchain.schema import RUN_KEY
|
||||||
@ -21,6 +22,7 @@ class FakeMemory(BaseMemory):
|
|||||||
"""Return baz variable."""
|
"""Return baz variable."""
|
||||||
return ["baz"]
|
return ["baz"]
|
||||||
|
|
||||||
|
@override
|
||||||
def load_memory_variables(
|
def load_memory_variables(
|
||||||
self,
|
self,
|
||||||
inputs: Optional[dict[str, Any]] = None,
|
inputs: Optional[dict[str, Any]] = None,
|
||||||
@ -52,6 +54,7 @@ class FakeChain(Chain):
|
|||||||
"""Output key of bar."""
|
"""Output key of bar."""
|
||||||
return self.the_output_keys
|
return self.the_output_keys
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
|
@ -19,7 +19,7 @@ def _fake_docs_len_func(docs: list[Document]) -> int:
|
|||||||
return len(_fake_combine_docs_func(docs))
|
return len(_fake_combine_docs_func(docs))
|
||||||
|
|
||||||
|
|
||||||
def _fake_combine_docs_func(docs: list[Document], **kwargs: Any) -> str:
|
def _fake_combine_docs_func(docs: list[Document], **_: Any) -> str:
|
||||||
return "".join([d.page_content for d in docs])
|
return "".join([d.page_content for d in docs])
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks import CallbackManagerForLLMRun
|
|||||||
from langchain_core.language_models import LLM
|
from langchain_core.language_models import LLM
|
||||||
from langchain_core.memory import BaseMemory
|
from langchain_core.memory import BaseMemory
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.conversation.base import ConversationChain
|
from langchain.chains.conversation.base import ConversationChain
|
||||||
from langchain.memory.buffer import ConversationBufferMemory
|
from langchain.memory.buffer import ConversationBufferMemory
|
||||||
@ -26,6 +27,7 @@ class DummyLLM(LLM):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "dummy"
|
return "dummy"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -10,6 +10,7 @@ from langchain_core.callbacks.manager import (
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.language_models.llms import BaseLLM
|
from langchain_core.language_models.llms import BaseLLM
|
||||||
from langchain_core.outputs import Generation, LLMResult
|
from langchain_core.outputs import Generation, LLMResult
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
||||||
from langchain.chains.hyde.prompts import PROMPT_MAP
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
||||||
@ -18,10 +19,12 @@ from langchain.chains.hyde.prompts import PROMPT_MAP
|
|||||||
class FakeEmbeddings(Embeddings):
|
class FakeEmbeddings(Embeddings):
|
||||||
"""Fake embedding class for tests."""
|
"""Fake embedding class for tests."""
|
||||||
|
|
||||||
|
@override
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
"""Return random floats."""
|
"""Return random floats."""
|
||||||
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
return [list(np.random.uniform(0, 1, 10)) for _ in range(10)]
|
||||||
|
|
||||||
|
@override
|
||||||
def embed_query(self, text: str) -> list[float]:
|
def embed_query(self, text: str) -> list[float]:
|
||||||
"""Return random floats."""
|
"""Return random floats."""
|
||||||
return list(np.random.uniform(0, 1, 10))
|
return list(np.random.uniform(0, 1, 10))
|
||||||
@ -32,6 +35,7 @@ class FakeLLM(BaseLLM):
|
|||||||
|
|
||||||
n: int = 1
|
n: int = 1
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
@ -41,6 +45,7 @@ class FakeLLM(BaseLLM):
|
|||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
|
||||||
|
|
||||||
|
@override
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self,
|
self,
|
||||||
prompts: list[str],
|
prompts: list[str],
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.callbacks.manager import (
|
|||||||
AsyncCallbackManagerForChainRun,
|
AsyncCallbackManagerForChainRun,
|
||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||||
@ -32,6 +33,7 @@ class FakeChain(Chain):
|
|||||||
"""Input keys this chain returns."""
|
"""Input keys this chain returns."""
|
||||||
return self.output_variables
|
return self.output_variables
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
@ -43,6 +45,7 @@ class FakeChain(Chain):
|
|||||||
outputs[var] = f"{' '.join(variables)}foo"
|
outputs[var] = f"{' '.join(variables)}foo"
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
@override
|
||||||
async def _acall(
|
async def _acall(
|
||||||
self,
|
self,
|
||||||
inputs: dict[str, str],
|
inputs: dict[str, str],
|
||||||
|
@ -4,6 +4,7 @@ from collections.abc import Iterator
|
|||||||
|
|
||||||
from langchain_core.document_loaders import BaseBlobParser, Blob
|
from langchain_core.document_loaders import BaseBlobParser, Blob
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
def test_base_blob_parser() -> None:
|
def test_base_blob_parser() -> None:
|
||||||
@ -12,6 +13,7 @@ def test_base_blob_parser() -> None:
|
|||||||
class MyParser(BaseBlobParser):
|
class MyParser(BaseBlobParser):
|
||||||
"""A simple parser that returns a single document."""
|
"""A simple parser that returns a single document."""
|
||||||
|
|
||||||
|
@override
|
||||||
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
def lazy_parse(self, blob: Blob) -> Iterator[Document]:
|
||||||
"""Lazy parsing interface."""
|
"""Lazy parsing interface."""
|
||||||
yield Document(
|
yield Document(
|
||||||
|
@ -7,12 +7,14 @@ import warnings
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.embeddings import CacheBackedEmbeddings
|
from langchain.embeddings import CacheBackedEmbeddings
|
||||||
from langchain.storage.in_memory import InMemoryStore
|
from langchain.storage.in_memory import InMemoryStore
|
||||||
|
|
||||||
|
|
||||||
class MockEmbeddings(Embeddings):
|
class MockEmbeddings(Embeddings):
|
||||||
|
@override
|
||||||
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
def embed_documents(self, texts: list[str]) -> list[list[float]]:
|
||||||
# Simulate embedding documents
|
# Simulate embedding documents
|
||||||
embeddings: list[list[float]] = []
|
embeddings: list[list[float]] = []
|
||||||
@ -23,6 +25,7 @@ class MockEmbeddings(Embeddings):
|
|||||||
embeddings.append([len(text), len(text) + 1])
|
embeddings.append([len(text), len(text) + 1])
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
@override
|
||||||
def embed_query(self, text: str) -> list[float]:
|
def embed_query(self, text: str) -> list[float]:
|
||||||
# Simulate embedding a query
|
# Simulate embedding a query
|
||||||
return [5.0, 6.0]
|
return [5.0, 6.0]
|
||||||
|
@ -9,6 +9,7 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.evaluation.agents.trajectory_eval_chain import (
|
from langchain.evaluation.agents.trajectory_eval_chain import (
|
||||||
TrajectoryEval,
|
TrajectoryEval,
|
||||||
@ -43,6 +44,7 @@ class _FakeTrajectoryChatModel(FakeChatModel):
|
|||||||
sequential_responses: Optional[bool] = False
|
sequential_responses: Optional[bool] = False
|
||||||
response_index: int = 0
|
response_index: int = 0
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -13,6 +13,7 @@ from langchain_core.documents import Document
|
|||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.indexing.api import _abatch, _get_document_with_hash
|
from langchain_core.indexing.api import _abatch, _get_document_with_hash
|
||||||
from langchain_core.vectorstores import VST, VectorStore
|
from langchain_core.vectorstores import VST, VectorStore
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.indexes import aindex, index
|
from langchain.indexes import aindex, index
|
||||||
from langchain.indexes._sql_record_manager import SQLRecordManager
|
from langchain.indexes._sql_record_manager import SQLRecordManager
|
||||||
@ -45,18 +46,21 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
self.store: dict[str, Document] = {}
|
self.store: dict[str, Document] = {}
|
||||||
self.permit_upserts = permit_upserts
|
self.permit_upserts = permit_upserts
|
||||||
|
|
||||||
|
@override
|
||||||
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
def delete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||||
"""Delete the given documents from the store using their IDs."""
|
"""Delete the given documents from the store using their IDs."""
|
||||||
if ids:
|
if ids:
|
||||||
for _id in ids:
|
for _id in ids:
|
||||||
self.store.pop(_id, None)
|
self.store.pop(_id, None)
|
||||||
|
|
||||||
|
@override
|
||||||
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
async def adelete(self, ids: Optional[Sequence[str]] = None, **kwargs: Any) -> None:
|
||||||
"""Delete the given documents from the store using their IDs."""
|
"""Delete the given documents from the store using their IDs."""
|
||||||
if ids:
|
if ids:
|
||||||
for _id in ids:
|
for _id in ids:
|
||||||
self.store.pop(_id, None)
|
self.store.pop(_id, None)
|
||||||
|
|
||||||
|
@override
|
||||||
def add_documents(
|
def add_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
@ -81,6 +85,7 @@ class InMemoryVectorStore(VectorStore):
|
|||||||
|
|
||||||
return list(ids)
|
return list(ids)
|
||||||
|
|
||||||
|
@override
|
||||||
async def aadd_documents(
|
async def aadd_documents(
|
||||||
self,
|
self,
|
||||||
documents: Sequence[Document],
|
documents: Sequence[Document],
|
||||||
|
@ -16,11 +16,13 @@ from langchain_core.messages import (
|
|||||||
)
|
)
|
||||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||||
from langchain_core.runnables import run_in_executor
|
from langchain_core.runnables import run_in_executor
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class FakeChatModel(SimpleChatModel):
|
class FakeChatModel(SimpleChatModel):
|
||||||
"""Fake Chat Model wrapper for testing purposes."""
|
"""Fake Chat Model wrapper for testing purposes."""
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -30,6 +32,7 @@ class FakeChatModel(SimpleChatModel):
|
|||||||
) -> str:
|
) -> str:
|
||||||
return "fake response"
|
return "fake response"
|
||||||
|
|
||||||
|
@override
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
@ -74,6 +77,7 @@ class GenericFakeChatModel(BaseChatModel):
|
|||||||
into message chunks.
|
into message chunks.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any, Optional, cast
|
|||||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||||
from langchain_core.language_models.llms import LLM
|
from langchain_core.language_models.llms import LLM
|
||||||
from pydantic import model_validator
|
from pydantic import model_validator
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class FakeLLM(LLM):
|
class FakeLLM(LLM):
|
||||||
@ -32,6 +33,7 @@ class FakeLLM(LLM):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "fake"
|
return "fake"
|
||||||
|
|
||||||
|
@override
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
@ -7,6 +7,7 @@ from uuid import UUID
|
|||||||
from langchain_core.callbacks.base import AsyncCallbackHandler
|
from langchain_core.callbacks.base import AsyncCallbackHandler
|
||||||
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
|
||||||
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
from tests.unit_tests.llms.fake_chat_model import GenericFakeChatModel
|
||||||
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
from tests.unit_tests.stubs import _AnyIdAIMessage, _AnyIdAIMessageChunk
|
||||||
@ -166,6 +167,7 @@ async def test_callback_handlers() -> None:
|
|||||||
# Required to implement since this is an abstract method
|
# Required to implement since this is an abstract method
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@override
|
||||||
async def on_llm_new_token(
|
async def on_llm_new_token(
|
||||||
self,
|
self,
|
||||||
token: str,
|
token: str,
|
||||||
|
@ -8,6 +8,7 @@ from langchain_core.messages import AIMessage
|
|||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.prompts.prompt import PromptTemplate
|
from langchain_core.prompts.prompt import PromptTemplate
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||||
@ -21,6 +22,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
|||||||
parse_count: int = 0 # Number of times parse has been called
|
parse_count: int = 0 # Number of times parse has been called
|
||||||
attemp_count_before_success: int # Number of times to fail before succeeding
|
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||||
|
|
||||||
|
@override
|
||||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||||
self.parse_count += 1
|
self.parse_count += 1
|
||||||
if self.parse_count <= self.attemp_count_before_success:
|
if self.parse_count <= self.attemp_count_before_success:
|
||||||
@ -62,7 +64,7 @@ def test_output_fixing_parser_parse(
|
|||||||
|
|
||||||
|
|
||||||
def test_output_fixing_parser_from_llm() -> None:
|
def test_output_fixing_parser_from_llm() -> None:
|
||||||
def fake_llm(prompt: str) -> AIMessage:
|
def fake_llm(_: str) -> AIMessage:
|
||||||
return AIMessage("2024-07-08T00:00:00.000000Z")
|
return AIMessage("2024-07-08T00:00:00.000000Z")
|
||||||
|
|
||||||
llm = RunnableLambda(fake_llm)
|
llm = RunnableLambda(fake_llm)
|
||||||
|
@ -7,6 +7,7 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.output_parsers import BaseOutputParser
|
from langchain_core.output_parsers import BaseOutputParser
|
||||||
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
from langchain_core.prompt_values import PromptValue, StringPromptValue
|
||||||
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.output_parsers.boolean import BooleanOutputParser
|
from langchain.output_parsers.boolean import BooleanOutputParser
|
||||||
from langchain.output_parsers.datetime import DatetimeOutputParser
|
from langchain.output_parsers.datetime import DatetimeOutputParser
|
||||||
@ -25,6 +26,7 @@ class SuccessfulParseAfterRetries(BaseOutputParser[str]):
|
|||||||
attemp_count_before_success: int # Number of times to fail before succeeding
|
attemp_count_before_success: int # Number of times to fail before succeeding
|
||||||
error_msg: str = "error"
|
error_msg: str = "error"
|
||||||
|
|
||||||
|
@override
|
||||||
def parse(self, *args: Any, **kwargs: Any) -> str:
|
def parse(self, *args: Any, **kwargs: Any) -> str:
|
||||||
self.parse_count += 1
|
self.parse_count += 1
|
||||||
if self.parse_count <= self.attemp_count_before_success:
|
if self.parse_count <= self.attemp_count_before_success:
|
||||||
|
@ -14,6 +14,7 @@ from langchain_core.structured_query import (
|
|||||||
StructuredQuery,
|
StructuredQuery,
|
||||||
Visitor,
|
Visitor,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.chains.query_constructor.schema import AttributeInfo
|
from langchain.chains.query_constructor.schema import AttributeInfo
|
||||||
from langchain.retrievers import SelfQueryRetriever
|
from langchain.retrievers import SelfQueryRetriever
|
||||||
@ -61,6 +62,7 @@ class FakeTranslator(Visitor):
|
|||||||
|
|
||||||
|
|
||||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
@override
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -1,5 +1,8 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
class SequentialRetriever(BaseRetriever):
|
class SequentialRetriever(BaseRetriever):
|
||||||
@ -8,17 +11,21 @@ class SequentialRetriever(BaseRetriever):
|
|||||||
sequential_responses: list[list[Document]]
|
sequential_responses: list[list[Document]]
|
||||||
response_index: int = 0
|
response_index: int = 0
|
||||||
|
|
||||||
def _get_relevant_documents( # type: ignore[override]
|
@override
|
||||||
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
if self.response_index >= len(self.sequential_responses):
|
if self.response_index >= len(self.sequential_responses):
|
||||||
return []
|
return []
|
||||||
self.response_index += 1
|
self.response_index += 1
|
||||||
return self.sequential_responses[self.response_index - 1]
|
return self.sequential_responses[self.response_index - 1]
|
||||||
|
|
||||||
async def _aget_relevant_documents( # type: ignore[override]
|
@override
|
||||||
|
async def _aget_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
**kwargs: Any,
|
||||||
) -> list[Document]:
|
) -> list[Document]:
|
||||||
return self._get_relevant_documents(query)
|
return self._get_relevant_documents(query)
|
||||||
|
@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.retrievers import BaseRetriever
|
from langchain_core.retrievers import BaseRetriever
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.retrievers.ensemble import EnsembleRetriever
|
from langchain.retrievers.ensemble import EnsembleRetriever
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ from langchain.retrievers.ensemble import EnsembleRetriever
|
|||||||
class MockRetriever(BaseRetriever):
|
class MockRetriever(BaseRetriever):
|
||||||
docs: list[Document]
|
docs: list[Document]
|
||||||
|
|
||||||
|
@override
|
||||||
def _get_relevant_documents(
|
def _get_relevant_documents(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
|
from langchain.retrievers.multi_vector import MultiVectorRetriever, SearchType
|
||||||
from langchain.storage import InMemoryStore
|
from langchain.storage import InMemoryStore
|
||||||
@ -15,6 +16,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
|||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
return self._identity_fn
|
return self._identity_fn
|
||||||
|
|
||||||
|
@override
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -26,6 +28,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
|||||||
return []
|
return []
|
||||||
return [res]
|
return [res]
|
||||||
|
|
||||||
|
@override
|
||||||
def similarity_search_with_score(
|
def similarity_search_with_score(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -3,6 +3,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_text_splitters.character import CharacterTextSplitter
|
from langchain_text_splitters.character import CharacterTextSplitter
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.retrievers import ParentDocumentRetriever
|
from langchain.retrievers import ParentDocumentRetriever
|
||||||
from langchain.storage import InMemoryStore
|
from langchain.storage import InMemoryStore
|
||||||
@ -10,6 +11,7 @@ from tests.unit_tests.indexes.test_indexing import InMemoryVectorStore
|
|||||||
|
|
||||||
|
|
||||||
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
||||||
|
@override
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -21,6 +23,7 @@ class InMemoryVectorstoreWithSearch(InMemoryVectorStore):
|
|||||||
return []
|
return []
|
||||||
return [res]
|
return [res]
|
||||||
|
|
||||||
|
@override
|
||||||
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
|
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> list[str]:
|
||||||
print(documents) # noqa: T201
|
print(documents) # noqa: T201
|
||||||
return super().add_documents(
|
return super().add_documents(
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
from langchain_core.documents import Document
|
from langchain_core.documents import Document
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.vectorstores import VectorStore
|
from langchain_core.vectorstores import VectorStore
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.retrievers.time_weighted_retriever import (
|
from langchain.retrievers.time_weighted_retriever import (
|
||||||
TimeWeightedVectorStoreRetriever,
|
TimeWeightedVectorStoreRetriever,
|
||||||
@ -31,6 +32,7 @@ def _get_example_memories(k: int = 4) -> list[Document]:
|
|||||||
class MockVectorStore(VectorStore):
|
class MockVectorStore(VectorStore):
|
||||||
"""Mock invalid vector store."""
|
"""Mock invalid vector store."""
|
||||||
|
|
||||||
|
@override
|
||||||
def add_texts(
|
def add_texts(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -39,6 +41,7 @@ class MockVectorStore(VectorStore):
|
|||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
return list(texts)
|
return list(texts)
|
||||||
|
|
||||||
|
@override
|
||||||
def similarity_search(
|
def similarity_search(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
@ -48,6 +51,7 @@ class MockVectorStore(VectorStore):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@override
|
||||||
def from_texts(
|
def from_texts(
|
||||||
cls: type["MockVectorStore"],
|
cls: type["MockVectorStore"],
|
||||||
texts: list[str],
|
texts: list[str],
|
||||||
@ -57,6 +61,7 @@ class MockVectorStore(VectorStore):
|
|||||||
) -> "MockVectorStore":
|
) -> "MockVectorStore":
|
||||||
return cls()
|
return cls()
|
||||||
|
|
||||||
|
@override
|
||||||
def _similarity_search_with_relevance_scores(
|
def _similarity_search_with_relevance_scores(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
|
@ -38,7 +38,7 @@ repo_dict = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def repo_lookup(owner_repo_commit: str, **kwargs: Any) -> ChatPromptTemplate:
|
def repo_lookup(owner_repo_commit: str, **_: Any) -> ChatPromptTemplate:
|
||||||
return repo_dict[owner_repo_commit]
|
return repo_dict[owner_repo_commit]
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,6 +6,7 @@ from langchain_core.messages import AIMessage, BaseMessage
|
|||||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
|
from langchain.runnables.openai_functions import OpenAIFunctionsRouter
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ class FakeChatOpenAI(BaseChatModel):
|
|||||||
def _llm_type(self) -> str:
|
def _llm_type(self) -> str:
|
||||||
return "fake-openai-chat-model"
|
return "fake-openai-chat-model"
|
||||||
|
|
||||||
|
@override
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
messages: list[BaseMessage],
|
messages: list[BaseMessage],
|
||||||
|
@ -3,16 +3,14 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Optional, Union
|
from typing import Any
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from langchain_core.language_models import BaseLanguageModel
|
|
||||||
from langsmith.client import Client
|
from langsmith.client import Client
|
||||||
from langsmith.schemas import Dataset, Example
|
from langsmith.schemas import Dataset, Example
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from langchain.chains.transform import TransformChain
|
from langchain.chains.transform import TransformChain
|
||||||
from langchain.smith.evaluation.runner_utils import (
|
from langchain.smith.evaluation.runner_utils import (
|
||||||
InputFormatError,
|
InputFormatError,
|
||||||
@ -243,7 +241,7 @@ def test_run_chat_model_all_formats(inputs: dict[str, Any]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
@freeze_time("2023-01-01")
|
||||||
async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_arun_on_dataset() -> None:
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
name="test",
|
name="test",
|
||||||
@ -298,22 +296,20 @@ async def test_arun_on_dataset(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def mock_read_dataset(*args: Any, **kwargs: Any) -> Dataset:
|
def mock_read_dataset(*_: Any, **__: Any) -> Dataset:
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def mock_list_examples(*args: Any, **kwargs: Any) -> Iterator[Example]:
|
def mock_list_examples(*_: Any, **__: Any) -> Iterator[Example]:
|
||||||
return iter(examples)
|
return iter(examples)
|
||||||
|
|
||||||
async def mock_arun_chain(
|
async def mock_arun_chain(
|
||||||
example: Example,
|
example: Example,
|
||||||
llm_or_chain: Union[BaseLanguageModel, Chain],
|
*_: Any,
|
||||||
tags: Optional[list[str]] = None,
|
**__: Any,
|
||||||
callbacks: Optional[Any] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {"result": f"Result for example {example.id}"}
|
return {"result": f"Result for example {example.id}"}
|
||||||
|
|
||||||
def mock_create_project(*args: Any, **kwargs: Any) -> Any:
|
def mock_create_project(*_: Any, **__: Any) -> Any:
|
||||||
proj = mock.MagicMock()
|
proj = mock.MagicMock()
|
||||||
proj.id = "123"
|
proj.id = "123"
|
||||||
return proj
|
return proj
|
||||||
|
@ -8,13 +8,13 @@ from langchain.tools.render import (
|
|||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def search(query: str) -> str:
|
def search(query: str) -> str: # noqa: ARG001
|
||||||
"""Lookup things online."""
|
"""Lookup things online."""
|
||||||
return "foo"
|
return "foo"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def calculator(expression: str) -> str:
|
def calculator(expression: str) -> str: # noqa: ARG001
|
||||||
"""Do math."""
|
"""Do math."""
|
||||||
return "bar"
|
return "bar"
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user