docstrings experimental (#7969)

- added/changed docstring for `experimental`
- added/changed docstrings for different artifacts
@baskaryan
Leonid Ganeline 2023-07-24 14:21:48 -07:00, committed by GitHub
parent 3eb4112a1f
commit c580c81cca
31 changed files with 133 additions and 30 deletions

View File

@@ -7,14 +7,20 @@ from langchain.schema import BaseOutputParser
 class AutoGPTAction(NamedTuple):
+    """Action for AutoGPT."""
+
     name: str
+    """Name of the action."""
     args: Dict
+    """Arguments for the action."""

 class BaseAutoGPTOutputParser(BaseOutputParser):
+    """Base class for AutoGPT output parsers."""
+
     @abstractmethod
     def parse(self, text: str) -> AutoGPTAction:
-        """Return AutoGPTAction"""
+        """Parse text and return AutoGPTAction"""

 def preprocess_json_input(input_str: str) -> str:
@@ -36,6 +42,8 @@ def preprocess_json_input(input_str: str) -> str:
 class AutoGPTOutputParser(BaseAutoGPTOutputParser):
+    """Output parser for AutoGPT."""
+
     def parse(self, text: str) -> AutoGPTAction:
         try:
             parsed = json.loads(text, strict=False)
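
For orientation, a minimal sketch of what parse() documents above: turning a JSON action string into an AutoGPTAction. It assumes the standard AutoGPT layout with a top-level "command" key and skips the parser's error-recovery branches.

import json
from typing import Dict, NamedTuple

class AutoGPTAction(NamedTuple):
    name: str
    args: Dict

text = '{"command": {"name": "search", "args": {"query": "langchain"}}}'
parsed = json.loads(text, strict=False)
action = AutoGPTAction(name=parsed["command"]["name"], args=parsed["command"]["args"])
print(action)  # AutoGPTAction(name='search', args={'query': 'langchain'})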

View File

@@ -123,7 +123,7 @@ class PromptGenerator:
 def get_prompt(tools: List[BaseTool]) -> str:
-    """This function generates a prompt string.
+    """Generate a prompt string.

     It includes various constraints, commands, resources, and performance evaluations.

View File

@@ -2,6 +2,8 @@ from enum import Enum
 class Constant(Enum):
+    """Enum for constants used in the CPAL."""
+
     narrative_input = "narrative_input"
     chain_answer = "chain_answer" # natural language answer
     chain_data = "chain_data" # pydantic instance

View File

@@ -55,6 +55,7 @@ def standardize_model_name(
 ) -> str:
     """
     Standardize the model name to a format that can be used in the OpenAI API.
+
     Args:
         model_name: Model name to standardize.
         is_completion: Whether the model is used for completion or not.

View File

@@ -53,9 +53,7 @@ GRAMMAR = """
 @v_args(inline=True)
 class QueryTransformer(Transformer):
-    """Transforms a query string into an IR representation
-    (intermediate representation).
-    """
+    """Transforms a query string into an intermediate representation."""

     def __init__(
         self,
View File

@@ -33,11 +33,16 @@ def _get_verbosity() -> bool:
 class BaseChatModel(BaseLanguageModel, ABC):
+    """Base class for chat models."""
+
     cache: Optional[bool] = None
+    """Whether to cache the response."""
     verbose: bool = Field(default_factory=_get_verbosity)
     """Whether to print out response text."""
     callbacks: Callbacks = Field(default=None, exclude=True)
+    """Callbacks to add to the run trace."""
     callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
+    """Callback manager to add to the run trace."""
     tags: Optional[List[str]] = Field(default=None, exclude=True)
     """Tags to add to the run trace."""
     metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
@@ -441,6 +446,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
 class SimpleChatModel(BaseChatModel):
+    """Simple Chat Model."""
+
     def _generate(
         self,
         messages: List[BaseMessage],
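
The newly documented fields are ordinary constructor arguments on any chat model subclass. A minimal sketch, using ChatOpenAI only as a familiar example:

from langchain.chat_models import ChatOpenAI

chat = ChatOpenAI(
    verbose=True,              # print out response text
    tags=["docs-demo"],        # tags added to the run trace
    metadata={"owner": "me"},  # metadata added to the run trace
)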

View File

@@ -15,6 +15,7 @@ BLOCK_URL = NOTION_BASE_URL + "/blocks/{block_id}/children"
 class NotionDBLoader(BaseLoader):
     """Notion DB Loader.
+
     Reads content from pages within a Notion Database.
     Args:
         integration_token (str): Notion integration token.
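
For context, a usage sketch of the loader being documented (the token and database_id values are placeholders):

from langchain.document_loaders import NotionDBLoader

loader = NotionDBLoader(
    integration_token="secret_...",  # Notion integration token
    database_id="897a5c45-...",      # ID of the Notion database to read
)
docs = loader.load()  # one Document per page in the database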

View File

@@ -8,6 +8,8 @@ from langchain.document_loaders.blob_loaders import Blob
 class ServerUnavailableException(Exception):
+    """Exception raised when the GROBID server is unavailable."""
+
     pass

View File

@@ -5,10 +5,13 @@ from langchain.schema import Document
 def default_joiner(docs: List[Tuple[str, Any]]) -> str:
+    """Default joiner for content columns."""
     return "\n".join([doc[1] for doc in docs])

 class ColumnNotFoundError(Exception):
+    """Column not found error."""
+
     def __init__(self, missing_key: str, query: str):
         super().__init__(f'Column "{missing_key}" not selected in query:\n{query}')
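
The joiner's behavior follows directly from the one-liner above: keep the second element of each (column, content) tuple and join with newlines. For instance:

docs = [("text", "first row"), ("text", "second row")]
print("\n".join([doc[1] for doc in docs]))
# first row
# second row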

View File

@@ -36,6 +36,8 @@ class TrajectoryEval(NamedTuple):
 class TrajectoryOutputParser(BaseOutputParser):
+    """Trajectory output parser."""
+
     @property
     def _type(self) -> str:
         return "agent_trajectory"

View File

@@ -7,6 +7,8 @@ from langchain.vectorstores.base import VectorStoreRetriever
 class AutoGPTMemory(BaseChatMemory):
+    """Memory for AutoGPT."""
+
     retriever: VectorStoreRetriever = Field(exclude=True)
     """VectorStoreRetriever object to connect to."""

View File

@@ -3,7 +3,7 @@ from langchain.schema.language_model import BaseLanguageModel
 class TaskCreationChain(LLMChain):
-    """Chain to generates tasks."""
+    """Chain generating tasks."""

     @classmethod
     def from_llm(cls, llm: BaseLanguageModel, verbose: bool = True) -> LLMChain:

View File

@@ -11,11 +11,10 @@ from langchain.schema.language_model import BaseLanguageModel
 class GenerativeAgent(BaseModel):
-    """A character with memory and innate characteristics."""
+    """An Agent as a character with memory and innate characteristics."""

     name: str
     """The character's name."""
     age: Optional[int] = None
     """The optional age of the character."""
     traits: str = "N/A"
@@ -29,13 +28,10 @@ class GenerativeAgent(BaseModel):
     verbose: bool = False
     summary: str = "" #: :meta private:
     """Stateful self-summary generated via reflection on the character's memory."""
     summary_refresh_seconds: int = 3600 #: :meta private:
     """How frequently to re-generate the summary."""
     last_refreshed: datetime = Field(default_factory=datetime.now) # : :meta private:
     """The last time the character's summary was regenerated."""
     daily_summaries: List[str] = Field(default_factory=list) # : :meta private:
     """Summary of the events in the plan that the agent took."""

View File

@@ -14,24 +14,21 @@ logger = logging.getLogger(__name__)
 class GenerativeAgentMemory(BaseMemory):
+    """Memory for the generative agent."""
+
     llm: BaseLanguageModel
     """The core language model."""
     memory_retriever: TimeWeightedVectorStoreRetriever
     """The retriever to fetch related memories."""
     verbose: bool = False
     reflection_threshold: Optional[float] = None
     """When aggregate_importance exceeds reflection_threshold, stop to reflect."""
     current_plan: List[str] = []
     """The current plan of the agent."""
     # A weight of 0.15 makes this less important than it
     # would be otherwise, relative to salience and time
     importance_weight: float = 0.15
     """How much weight to assign the memory importance."""
     aggregate_importance: float = 0.0 # : :meta private:
     """Track the sum of the 'importance' of recent memories.

View File

@@ -18,7 +18,7 @@ def import_jsonformer() -> jsonformer:
     try:
         import jsonformer
     except ImportError:
-        raise ValueError(
+        raise ImportError(
            "Could not import jsonformer python package. "
            "Please install it with `pip install jsonformer`."
        )
@@ -26,6 +26,11 @@ def import_jsonformer() -> jsonformer:
 class JsonFormer(HuggingFacePipeline):
+    """Jsonformer wrapped LLM using HuggingFace Pipeline API.
+
+    This pipeline is experimental and not yet stable.
+    """
+
     json_schema: dict = Field(..., description="The JSON Schema to complete.")
     max_new_tokens: int = Field(
         default=200, description="Maximum number of new tokens to generate."
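
The ValueError-to-ImportError change makes the optional-dependency guard idiomatic, so callers catching ImportError behave as expected. The pattern in general form (a sketch, not library code):

def import_optional(name: str):
    """Import an optional dependency, failing with an actionable message."""
    try:
        return __import__(name)
    except ImportError:
        raise ImportError(
            f"Could not import {name} python package. "
            f"Please install it with `pip install {name}`."
        )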

View File

@@ -24,7 +24,7 @@ def import_rellm() -> rellm:
     try:
         import rellm
     except ImportError:
-        raise ValueError(
+        raise ImportError(
            "Could not import rellm python package. "
            "Please install it with `pip install rellm`."
        )
@@ -32,6 +32,8 @@ def import_rellm() -> rellm:
 class RELLM(HuggingFacePipeline):
+    """RELLM wrapped LLM using HuggingFace Pipeline API."""
+
     regex: RegexPattern = Field(..., description="The structured format to complete.")
     max_new_tokens: int = Field(
         default=200, description="Maximum number of new tokens to generate."

View File

@@ -13,9 +13,14 @@ from langchain.experimental.plan_and_execute.schema import (
 class PlanAndExecute(Chain):
+    """Plan and execute a chain of steps."""
+
     planner: BasePlanner
+    """The planner to use."""
     executor: BaseExecutor
+    """The executor to use."""
     step_container: BaseStepContainer = Field(default_factory=ListStepContainer)
+    """The step container to use."""
     input_key: str = "input"
     output_key: str = "output"
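
A usage sketch wiring the documented fields together; load_agent_executor is assumed to come from the same experimental package, and the empty tools list is a placeholder:

from langchain.chat_models import ChatOpenAI
from langchain.experimental.plan_and_execute import (
    PlanAndExecute,
    load_agent_executor,
    load_chat_planner,
)

llm = ChatOpenAI(temperature=0)
tools = []  # placeholder: supply real tools
chain = PlanAndExecute(
    planner=load_chat_planner(llm),            # the planner to use
    executor=load_agent_executor(llm, tools),  # the executor to use
)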

View File

@@ -9,6 +9,8 @@ from langchain.experimental.plan_and_execute.schema import StepResponse
 class BaseExecutor(BaseModel):
+    """Base executor."""
+
     @abstractmethod
     def step(
         self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
@@ -19,11 +21,14 @@ class BaseExecutor(BaseModel):
     async def astep(
         self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
     ) -> StepResponse:
-        """Take step."""
+        """Take async step."""

 class ChainExecutor(BaseExecutor):
+    """Chain executor."""
+
     chain: Chain
+    """The chain to use."""

     def step(
         self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any

View File

@@ -9,6 +9,8 @@ from langchain.experimental.plan_and_execute.schema import Plan, PlanOutputParser
 class BasePlanner(BaseModel):
+    """Base planner."""
+
     @abstractmethod
     def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
         """Given input, decide what to do."""
@@ -17,13 +19,18 @@ class BasePlanner(BaseModel):
     async def aplan(
         self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
     ) -> Plan:
-        """Given input, decide what to do."""
+        """Given input, asynchronously decide what to do."""

 class LLMPlanner(BasePlanner):
+    """LLM planner."""
+
     llm_chain: LLMChain
+    """The LLM chain to use."""
     output_parser: PlanOutputParser
+    """The output parser to use."""
     stop: Optional[List] = None
+    """The stop list to use."""

     def plan(self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any) -> Plan:
         """Given input, decide what to do."""
@@ -33,7 +40,7 @@ class LLMPlanner(BasePlanner):
     async def aplan(
         self, inputs: dict, callbacks: Callbacks = None, **kwargs: Any
     ) -> Plan:
-        """Given input, decide what to do."""
+        """Given input, asynchronously decide what to do."""
         llm_response = await self.llm_chain.arun(
             **inputs, stop=self.stop, callbacks=callbacks
         )

View File

@@ -24,6 +24,8 @@ SYSTEM_PROMPT = (
 class PlanningOutputParser(PlanOutputParser):
+    """Planning output parser."""
+
     def parse(self, text: str) -> Plan:
         steps = [Step(value=v) for v in re.split("\n\s*\d+\. ", text)[1:]]
         return Plan(steps=steps)
@@ -34,6 +36,7 @@ def load_chat_planner(
 ) -> LLMPlanner:
     """
     Load a chat planner.
+
     Args:
         llm: Language model.
         system_prompt: System prompt.
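
The regex in parse() splits a numbered list out of the model's text, dropping any preamble before step 1. For example:

import re

text = "Plan:\n1. Gather the data.\n2. Analyze it.\n3. Report."
steps = re.split(r"\n\s*\d+\. ", text)[1:]
print(steps)  # ['Gather the data.', 'Analyze it.', 'Report.']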

View File

@@ -7,18 +7,29 @@ from langchain.schema import BaseOutputParser
 class Step(BaseModel):
+    """Step."""
+
     value: str
+    """The value."""

 class Plan(BaseModel):
+    """Plan."""
+
     steps: List[Step]
+    """The steps."""

 class StepResponse(BaseModel):
+    """Step response."""
+
     response: str
+    """The response."""

 class BaseStepContainer(BaseModel):
+    """Base step container."""
+
     @abstractmethod
     def add_step(self, step: Step, step_response: StepResponse) -> None:
         """Add step and step response to the container."""
@@ -29,7 +40,10 @@ class BaseStepContainer(BaseModel):
 class ListStepContainer(BaseStepContainer):
+    """List step container."""
+
     steps: List[Tuple[Step, StepResponse]] = Field(default_factory=list)
+    """The steps."""

     def add_step(self, step: Step, step_response: StepResponse) -> None:
         self.steps.append((step, step_response))
@@ -42,6 +56,8 @@ class ListStepContainer(BaseStepContainer):
 class PlanOutputParser(BaseOutputParser):
+    """Plan output parser."""
+
     @abstractmethod
     def parse(self, text: str) -> Plan:
         """Parse into a plan."""

View File

@@ -259,7 +259,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
 class CharacterTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at characters."""
+    """Splitting text that looks at characters."""

     def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Create a new TextSplitter."""
@@ -290,7 +290,7 @@ class HeaderType(TypedDict):
 class MarkdownHeaderTextSplitter:
-    """Implementation of splitting markdown files based on specified headers."""
+    """Splitting markdown files based on specified headers."""

     def __init__(
         self, headers_to_split_on: List[Tuple[str, str]], return_each_line: bool = False
@@ -443,7 +443,7 @@ class Tokenizer:
 def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
-    """Split incoming text and return chunks."""
+    """Split incoming text and return chunks using tokenizer."""
     splits: List[str] = []
     input_ids = tokenizer.encode(text)
     start_idx = 0
@@ -458,7 +458,7 @@ def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> List[str]:
 class TokenTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at tokens."""
+    """Splitting text to tokens using model tokenizer."""

     def __init__(
         self,
@@ -506,7 +506,7 @@ class TokenTextSplitter(TextSplitter):
 class SentenceTransformersTokenTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at tokens."""
+    """Splitting text to tokens using sentence model tokenizer."""

     def __init__(
         self,
@@ -599,7 +599,7 @@ class Language(str, Enum):
 class RecursiveCharacterTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at characters.
+    """Splitting text by recursively look at characters.

     Recursively tries to split by different characters to find one
     that works.
@@ -1004,7 +1004,7 @@ class RecursiveCharacterTextSplitter(TextSplitter):
 class NLTKTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at sentences using NLTK."""
+    """Splitting text using NLTK package."""

     def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
         """Initialize the NLTK splitter."""
@@ -1027,7 +1027,7 @@ class NLTKTextSplitter(TextSplitter):
 class SpacyTextSplitter(TextSplitter):
-    """Implementation of splitting text that looks at sentences using Spacy.
+    """Splitting text using Spacy package.

     Per default, Spacy's `en_core_web_sm` model is used. For a faster, but
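
As quick orientation for the renamed splitter docstrings, a minimal sketch using the recursive splitter (parameter values are arbitrary):

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=20)
chunks = splitter.split_text("A long document. " * 50)
print(len(chunks), repr(chunks[0]))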

View File

@@ -8,9 +8,14 @@ from langchain.schema import Document
 class BraveSearchWrapper(BaseModel):
+    """Wrapper around the Brave search engine."""
+
     api_key: str
+    """The API key to use for the Brave search engine."""
     search_kwargs: dict = Field(default_factory=dict)
+    """Additional keyword arguments to pass to the search request."""
     base_url = "https://api.search.brave.com/res/v1/web/search"
+    """The base URL for the Brave search engine."""

     def run(self, query: str) -> str:
         """Query the Brave search engine and return the results as a JSON string.

View File

@@ -10,6 +10,8 @@ from langchain.utils import get_from_dict_or_env
 class DataForSeoAPIWrapper(BaseModel):
+    """Wrapper around the DataForSeo API."""
+
     class Config:
         """Configuration for this pydantic object."""
@@ -25,13 +27,21 @@ class DataForSeoAPIWrapper(BaseModel):
             "se_type": "organic",
         }
     )
+    """Default parameters to use for the DataForSEO SERP API."""
     params: dict = Field(default={})
+    """Additional parameters to pass to the DataForSEO SERP API."""
     api_login: Optional[str] = None
+    """The API login to use for the DataForSEO SERP API."""
     api_password: Optional[str] = None
+    """The API password to use for the DataForSEO SERP API."""
     json_result_types: Optional[list] = None
+    """The JSON result types."""
     json_result_fields: Optional[list] = None
+    """The JSON result fields."""
     top_count: Optional[int] = None
+    """The number of top results to return."""
     aiosession: Optional[aiohttp.ClientSession] = None
+    """The aiohttp session to use for the DataForSEO SERP API."""

     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
 @functools.lru_cache(maxsize=None)
 def warn_once() -> None:
-    # Warn that the PythonREPL
+    """Warn once about the dangers of PythonREPL."""
     logger.warning("Python REPL can execute arbitrary code. Use with caution.")
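
The functools.lru_cache(maxsize=None) decorator is what makes the warning fire only once per process: the first call is cached, so repeat calls never reach the body. The same pattern in isolation:

import functools
import logging

logger = logging.getLogger(__name__)

@functools.lru_cache(maxsize=None)
def warn_once() -> None:
    """Warn once about the dangers of PythonREPL."""
    logger.warning("Python REPL can execute arbitrary code. Use with caution.")

warn_once()  # logs the warning
warn_once()  # cached; nothing is logged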

View File

@@ -166,6 +166,8 @@ def _get_search_client(
 class AzureSearch(VectorStore):
+    """Azure Cognitive Search vector store."""
+
     def __init__(
         self,
         azure_search_endpoint: str,
@@ -481,9 +483,15 @@ class AzureSearch(VectorStore):
 class AzureSearchVectorStoreRetriever(BaseRetriever):
+    """Retriever that uses Azure Search to find similar documents."""
+
     vectorstore: AzureSearch
+    """Azure Search instance used to find similar documents."""
     search_type: str = "hybrid"
+    """Type of search to perform. Options are "similarity", "hybrid",
+    "semantic_hybrid"."""
     k: int = 4
+    """Number of documents to return."""

     class Config:
         """Configuration for this pydantic object."""

View File

@@ -460,9 +460,14 @@ class VectorStore(ABC):
 class VectorStoreRetriever(BaseRetriever):
+    """Retriever class for VectorStore."""
+
     vectorstore: VectorStore
+    """VectorStore to use for retrieval."""
     search_type: str = "similarity"
+    """Type of search to perform. Defaults to "similarity"."""
     search_kwargs: dict = Field(default_factory=dict)
+    """Keyword arguments to pass to the search function."""
     allowed_search_types: ClassVar[Collection[str]] = (
         "similarity",
         "similarity_score_threshold",

View File

@@ -94,6 +94,7 @@ class QueryResult:
 class PGEmbedding(VectorStore):
     """
     VectorStore implementation using Postgres and the pg_embedding extension.
+
     pg_embedding uses sequential scan by default. but you can create a HNSW index
     using the create_hnsw_index method.
     - `connection_string` is a postgres connection string.

View File

@@ -612,10 +612,16 @@ class Redis(VectorStore):
 class RedisVectorStoreRetriever(VectorStoreRetriever):
+    """Retriever for Redis VectorStore."""
+
     vectorstore: Redis
+    """Redis VectorStore."""
     search_type: str = "similarity"
+    """Type of search to perform. Can be either 'similarity' or 'similarity_limit'."""
     k: int = 4
+    """Number of documents to return."""
     score_threshold: float = 0.4
+    """Score threshold for similarity_limit search."""

     class Config:
         """Configuration for this pydantic object."""

View File

@@ -9,6 +9,9 @@ from langchain.utils.math import cosine_similarity
 class DistanceStrategy(str, Enum):
+    """Enumerator of the Distance strategies for calculating distances
+    between vectors."""
+
     EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
     MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
     DOT_PRODUCT = "DOT_PRODUCT"

View File

@@ -412,7 +412,10 @@ class Vectara(VectorStore):
 class VectaraRetriever(VectorStoreRetriever):
+    """Retriever class for Vectara."""
+
     vectorstore: Vectara
+    """Vectara vectorstore."""
     search_kwargs: dict = Field(
         default_factory=lambda: {
             "lambda_val": 0.025,