docstrings retrievers (#7858)

Added/updated docstrings for `retrievers`.

@baskaryan
Authored by Leonid Ganeline on 2023-07-17 17:47:17 -07:00, committed by GitHub
parent 5b4d53e8ef
commit 74b701f42b
36 changed files with 235 additions and 64 deletions

View File

@@ -10,7 +10,8 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
 class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
     """
-    It is effectively a wrapper for ArxivAPIWrapper.
+    Retriever for Arxiv.
+
     It wraps load() to get_relevant_documents().
     It uses all ArxivAPIWrapper arguments without any change.
     """

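For reference, a minimal usage sketch of the retriever documented above; since it wraps load() as get_relevant_documents(), ArxivAPIWrapper arguments such as load_max_docs pass straight through (the query is illustrative):

    from langchain.retrievers import ArxivRetriever

    # ArxivAPIWrapper arguments are accepted unchanged.
    retriever = ArxivRetriever(load_max_docs=2)
    docs = retriever.get_relevant_documents("quantum entanglement")
    print(docs[0].metadata)
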
View File

@@ -1,4 +1,4 @@
-"""Retriever wrapper for Azure Cognitive Search."""
+"""Retriever for the Azure Cognitive Search service."""
 from __future__ import annotations
@@ -18,7 +18,7 @@ from langchain.utils import get_from_dict_or_env
 class AzureCognitiveSearchRetriever(BaseRetriever):
-    """Wrapper around Azure Cognitive Search."""
+    """Retriever for the Azure Cognitive Search service."""

     service_name: str = ""
     """Name of Azure Cognitive Search service"""

View File

@@ -19,10 +19,16 @@ def default_preprocessing_func(text: str) -> List[str]:
 class BM25Retriever(BaseRetriever):
+    """BM25 Retriever without elastic search."""
+
     vectorizer: Any
+    """ BM25 vectorizer."""
     docs: List[Document]
+    """ List of documents."""
     k: int = 4
+    """ Number of documents to return."""
     preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
+    """ Preprocessing function to use on the text before BM25 vectorization."""

     class Config:
         """Configuration for this pydantic object."""
@@ -38,6 +44,18 @@ class BM25Retriever(BaseRetriever):
         preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
         **kwargs: Any,
     ) -> BM25Retriever:
+        """
+        Create a BM25Retriever from a list of texts.
+        Args:
+            texts: A list of texts to vectorize.
+            metadatas: A list of metadata dicts to associate with each text.
+            bm25_params: Parameters to pass to the BM25 vectorizer.
+            preprocess_func: A function to preprocess each text before vectorization.
+            **kwargs: Any other arguments to pass to the retriever.
+
+        Returns:
+            A BM25Retriever instance.
+        """
         try:
             from rank_bm25 import BM25Okapi
         except ImportError:
@@ -64,6 +82,17 @@ class BM25Retriever(BaseRetriever):
         preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
         **kwargs: Any,
     ) -> BM25Retriever:
+        """
+        Create a BM25Retriever from a list of Documents.
+        Args:
+            documents: A list of Documents to vectorize.
+            bm25_params: Parameters to pass to the BM25 vectorizer.
+            preprocess_func: A function to preprocess each text before vectorization.
+            **kwargs: Any other arguments to pass to the retriever.
+
+        Returns:
+            A BM25Retriever instance.
+        """
         texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
         return cls.from_texts(
             texts=texts,

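A short sketch of the two constructors documented above (assumes the rank_bm25 package is installed; the texts are placeholders):

    from langchain.retrievers import BM25Retriever
    from langchain.schema import Document

    # from_texts builds the BM25 index directly from strings.
    retriever = BM25Retriever.from_texts(["foo", "bar", "world"], k=2)

    # from_documents delegates to from_texts via each Document's page_content.
    retriever = BM25Retriever.from_documents(
        [Document(page_content="foo"), Document(page_content="bar")]
    )
    docs = retriever.get_relevant_documents("foo")
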
View File

@@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document
 class ChaindeskRetriever(BaseRetriever):
-    """Retriever that uses the Chaindesk API."""
+    """Retriever for the Chaindesk API."""

     datastore_url: str
     top_k: Optional[int]

View File

@@ -13,16 +13,24 @@ from langchain.schema import BaseRetriever, Document
 class ChatGPTPluginRetriever(BaseRetriever):
+    """Retrieves documents from a ChatGPT plugin."""
+
     url: str
+    """URL of the ChatGPT plugin."""
     bearer_token: str
+    """Bearer token for the ChatGPT plugin."""
     top_k: int = 3
+    """Number of documents to return."""
     filter: Optional[dict] = None
+    """Filter to apply to the results."""
     aiosession: Optional[aiohttp.ClientSession] = None
+    """Aiohttp session to use for requests."""

     class Config:
         """Configuration for this pydantic object."""

         arbitrary_types_allowed = True
+        """Allow arbitrary types."""

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun

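The documented fields double as constructor arguments; a hedged sketch (URL and token are placeholders for a running plugin):

    from langchain.retrievers import ChatGPTPluginRetriever

    retriever = ChatGPTPluginRetriever(
        url="http://localhost:8000",       # plugin endpoint (placeholder)
        bearer_token="YOUR_BEARER_TOKEN",  # placeholder
        top_k=3,
    )
    docs = retriever.get_relevant_documents("alice's phone number")
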
View File

@@ -1,5 +1,3 @@
-"""Retriever that wraps a base retriever and filters the results."""
-
 from typing import Any, List

 from langchain.callbacks.manager import (

View File

@@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document
 class DataberryRetriever(BaseRetriever):
-    """Retriever that uses the Databerry API."""
+    """Retriever for the Databerry API."""

     datastore_url: str
     top_k: Optional[int]

View File

@@ -21,7 +21,7 @@ class SearchType(str, Enum):
 class DocArrayRetriever(BaseRetriever):
     """
-    Retriever class for DocArray Document Indices.
+    Retriever for DocArray Document Indices.

     Currently, supports 5 backends:
     InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,

View File

@@ -42,6 +42,9 @@ def _get_default_chain_prompt() -> PromptTemplate:
 class LLMChainExtractor(BaseDocumentCompressor):
+    """DocumentCompressor that uses an LLM chain to extract
+    the relevant parts of documents."""
+
     llm_chain: LLMChain
     """LLM wrapper to use for compressing documents."""

View File

@@ -68,6 +68,16 @@ class LLMChainFilter(BaseDocumentCompressor):
         prompt: Optional[BasePromptTemplate] = None,
         **kwargs: Any
     ) -> "LLMChainFilter":
+        """Create a LLMChainFilter from a language model.
+
+        Args:
+            llm: The language model to use for filtering.
+            prompt: The prompt to use for the filter.
+            **kwargs: Additional arguments to pass to the constructor.
+
+        Returns:
+            A LLMChainFilter that uses the given language model.
+        """
         _prompt = prompt if prompt is not None else _get_default_chain_prompt()
         llm_chain = LLMChain(llm=llm, prompt=_prompt)
         return cls(llm_chain=llm_chain, **kwargs)

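A minimal sketch of the from_llm constructor documented above (assumes an OpenAI key is configured):

    from langchain.llms import OpenAI
    from langchain.retrievers.document_compressors import LLMChainFilter

    # Builds an LLMChain with the default filtering prompt.
    _filter = LLMChainFilter.from_llm(OpenAI(temperature=0))
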
View File

@@ -21,9 +21,14 @@ else:
 class CohereRerank(BaseDocumentCompressor):
+    """DocumentCompressor that uses Cohere's rerank API to compress documents."""
+
     client: Client
+    """Cohere client to use for compressing documents."""
     top_n: int = 3
+    """Number of documents to return."""
     model: str = "rerank-english-v2.0"
+    """Model to use for reranking."""

     class Config:
         """Configuration for this pydantic object."""
@@ -54,6 +59,17 @@ class CohereRerank(BaseDocumentCompressor):
         query: str,
         callbacks: Optional[Callbacks] = None,
     ) -> Sequence[Document]:
+        """
+        Compress documents using Cohere's rerank API.
+
+        Args:
+            documents: A sequence of documents to compress.
+            query: The query to use for compressing the documents.
+            callbacks: Callbacks to run during the compression process.
+
+        Returns:
+            A sequence of compressed documents.
+        """
         if len(documents) == 0:  # to avoid empty api call
             return []
         doc_list = list(documents)

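compress_documents is typically reached through a ContextualCompressionRetriever; a hedged sketch (assumes COHERE_API_KEY is set and base_retriever is any existing retriever):

    from langchain.retrievers import ContextualCompressionRetriever
    from langchain.retrievers.document_compressors import CohereRerank

    compressor = CohereRerank(top_n=3)  # client is created from COHERE_API_KEY
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=base_retriever
    )
    docs = compression_retriever.get_relevant_documents("What did the president say?")
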
View File

@@ -1,4 +1,3 @@
-"""Document compressor that uses embeddings to drop documents unrelated to the query."""
 from typing import Callable, Dict, Optional, Sequence

 import numpy as np
@@ -18,6 +17,9 @@ from langchain.schema import Document
 class EmbeddingsFilter(BaseDocumentCompressor):
+    """Document compressor that uses embeddings to drop documents
+    unrelated to the query."""
+
     embeddings: Embeddings
     """Embeddings to use for embedding document contents and queries."""
     similarity_fn: Callable = cosine_similarity

View File

@@ -14,8 +14,7 @@ from langchain.schema import BaseRetriever
 class ElasticSearchBM25Retriever(BaseRetriever):
-    """Wrapper around Elasticsearch using BM25 as a retrieval method.
-
+    """Retriever for the Elasticsearch using BM25 as a retrieval method.

     To connect to an Elasticsearch instance that requires login credentials,
     including Elastic Cloud, use the Elasticsearch URL format
@@ -41,12 +40,26 @@ class ElasticSearchBM25Retriever(BaseRetriever):
     """

     client: Any
+    """Elasticsearch client."""
     index_name: str
+    """Name of the index to use in Elasticsearch."""

     @classmethod
     def create(
         cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
     ) -> ElasticSearchBM25Retriever:
+        """
+        Create a ElasticSearchBM25Retriever from a list of texts.
+
+        Args:
+            elasticsearch_url: URL of the Elasticsearch instance to connect to.
+            index_name: Name of the index to use in Elasticsearch.
+            k1: BM25 parameter k1.
+            b: BM25 parameter b.
+
+        Returns:
+        """
         from elasticsearch import Elasticsearch

         # Create an Elasticsearch client instance

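A hedged sketch of create() followed by retrieval (assumes a reachable Elasticsearch instance; URL and index name are placeholders):

    from langchain.retrievers import ElasticSearchBM25Retriever

    # create() provisions a BM25-configured index on the instance.
    retriever = ElasticSearchBM25Retriever.create(
        elasticsearch_url="http://localhost:9200", index_name="langchain-index"
    )
    retriever.add_texts(["foo", "bar", "world"])
    docs = retriever.get_relevant_documents("foo")
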
View File

@@ -51,24 +51,38 @@ class Highlight(BaseModel, extra=Extra.allow):
 class TextWithHighLights(BaseModel, extra=Extra.allow):
+    """Text with highlights."""
+
     Text: str
+    """The text."""
     Highlights: Optional[Any]
+    """The highlights."""


 class AdditionalResultAttributeValue(BaseModel, extra=Extra.allow):
+    """The value of an additional result attribute."""
+
     TextWithHighlightsValue: TextWithHighLights
+    """The text with highlights value."""


 class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
+    """An additional result attribute."""
+
     Key: str
+    """The key of the attribute."""
     ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
+    """The type of the value."""
     Value: AdditionalResultAttributeValue
+    """The value of the attribute."""

     def get_value_text(self) -> str:
         return self.Value.TextWithHighlightsValue.Text


 class QueryResultItem(BaseModel, extra=Extra.allow):
+    """A query result item."""
+
     DocumentId: str
     DocumentTitle: TextWithHighLights
     DocumentURI: Optional[str]
@@ -111,9 +125,19 @@ class QueryResultItem(BaseModel, extra=Extra.allow):
 class QueryResult(BaseModel, extra=Extra.allow):
+    """A query result."""
+
     ResultItems: List[QueryResultItem]

     def get_top_k_docs(self, top_n: int) -> List[Document]:
+        """Gets the top k documents.
+
+        Args:
+            top_n: The number of documents to return.
+
+        Returns:
+            The top k documents.
+        """
         items_len = len(self.ResultItems)
         count = items_len if items_len < top_n else top_n
         docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
@@ -122,24 +146,42 @@ class QueryResult(BaseModel, extra=Extra.allow):
 class DocumentAttributeValue(BaseModel, extra=Extra.allow):
+    """The value of a document attribute."""
+
     DateValue: Optional[str]
+    """The date value."""
     LongValue: Optional[int]
+    """The long value."""
     StringListValue: Optional[List[str]]
+    """The string list value."""
     StringValue: Optional[str]
+    """The string value."""


 class DocumentAttribute(BaseModel, extra=Extra.allow):
+    """A document attribute."""
+
     Key: str
+    """The key of the attribute."""
     Value: DocumentAttributeValue
+    """The value of the attribute."""


 class RetrieveResultItem(BaseModel, extra=Extra.allow):
+    """A retrieve result item."""
+
     Content: Optional[str]
+    """The content of the item."""
     DocumentAttributes: Optional[List[DocumentAttribute]] = []
+    """The document attributes."""
     DocumentId: Optional[str]
+    """The document ID."""
     DocumentTitle: Optional[str]
+    """The document title."""
     DocumentURI: Optional[str]
+    """The document URI."""
     Id: Optional[str]
+    """The ID of the item."""

     def get_excerpt(self) -> str:
         if not self.Content:
@@ -156,8 +198,12 @@ class RetrieveResultItem(BaseModel, extra=Extra.allow):
 class RetrieveResult(BaseModel, extra=Extra.allow):
+    """A retrieve result."""
+
     QueryId: str
+    """The ID of the query."""
     ResultItems: List[RetrieveResultItem]
+    """The result items."""

     def get_top_k_docs(self, top_n: int) -> List[Document]:
         items_len = len(self.ResultItems)
@@ -168,7 +214,7 @@ class RetrieveResult(BaseModel, extra=Extra.allow):
 class AmazonKendraRetriever(BaseRetriever):
-    """Retriever class to query documents from Amazon Kendra Index.
+    """Retriever for the Amazon Kendra Index.

     Args:
         index_id: Kendra index id

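A minimal sketch of the retriever documented above (the index id is a placeholder; AWS credentials come from the boto3 environment):

    from langchain.retrievers import AmazonKendraRetriever

    retriever = AmazonKendraRetriever(index_id="<YOUR-KENDRA-INDEX-ID>")
    docs = retriever.get_relevant_documents("What is Amazon Kendra?")
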
View File

@@ -36,10 +36,15 @@ class KNNRetriever(BaseRetriever):
     """KNN Retriever."""

     embeddings: Embeddings
+    """Embeddings model to use."""
     index: Any
+    """Index of embeddings."""
     texts: List[str]
+    """List of texts to index."""
     k: int = 4
+    """Number of results to return."""
     relevancy_threshold: Optional[float] = None
+    """Threshold for relevancy."""

     class Config:

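A short sketch of the fields above in use, via the from_texts constructor (assumes an OpenAI key for the embeddings):

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.retrievers import KNNRetriever

    retriever = KNNRetriever.from_texts(["foo", "bar", "world"], OpenAIEmbeddings())
    docs = retriever.get_relevant_documents("foo")
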
View File

@@ -10,10 +10,13 @@ from langchain.schema import BaseRetriever, Document
 class LlamaIndexRetriever(BaseRetriever):
-    """Question-answering with sources over an LlamaIndex data structure."""
+    """Retriever for the question-answering with sources over
+    an LlamaIndex data structure."""

     index: Any
+    """LlamaIndex index to query."""
     query_kwargs: Dict = Field(default_factory=dict)
+    """Keyword arguments to pass to the query method."""

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun
@@ -46,10 +49,13 @@ class LlamaIndexRetriever(BaseRetriever):
 class LlamaIndexGraphRetriever(BaseRetriever):
-    """Question-answering with sources over an LlamaIndex graph data structure."""
+    """Retriever for question-answering with sources over an LlamaIndex
+    graph data structure."""

     graph: Any
+    """LlamaIndex graph to query."""
     query_configs: List[Dict] = Field(default_factory=list)
+    """List of query configs to pass to the query method."""

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -8,14 +8,10 @@ from langchain.schema import BaseRetriever, Document
 class MergerRetriever(BaseRetriever):
-    """
-    This class merges the results of multiple retrievers.
-
-    Args:
-        retrievers: A list of retrievers to merge.
-    """
+    """Retriever that merges the results of multiple retrievers."""

     retrievers: List[BaseRetriever]
+    """A list of retrievers to merge."""

     def _get_relevant_documents(
         self,

View File

@@ -13,8 +13,9 @@ class MetalRetriever(BaseRetriever):
     """Retriever that uses the Metal API."""

     client: Any
+    """The Metal client to use."""
     params: Optional[dict] = None
+    """The parameters to pass to the Metal client."""

     @root_validator(pre=True)
     def validate_client(cls, values: dict) -> dict:

View File

@@ -17,10 +17,15 @@ logger = logging.getLogger(__name__)
 class LineList(BaseModel):
+    """List of lines."""
+
     lines: List[str] = Field(description="Lines of text")
+    """List of lines."""


 class LineListOutputParser(PydanticOutputParser):
+    """Output parser for a list of lines."""
+
     def __init__(self) -> None:
         super().__init__(pydantic_object=LineList)

View File

@@ -99,12 +99,19 @@ def create_index(
 class PineconeHybridSearchRetriever(BaseRetriever):
+    """Pinecone Hybrid Search Retriever."""
+
     embeddings: Embeddings
+    """Embeddings model to use."""
     """description"""
     sparse_encoder: Any
+    """Sparse encoder to use."""
     index: Any
+    """Pinecone index to use."""
     top_k: int = 4
+    """Number of documents to return."""
     alpha: float = 0.5
+    """Alpha value for hybrid search."""

     class Config:
         """Configuration for this pydantic object."""

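A hedged sketch of the documented fields in use (assumes an existing pinecone.Index bound to the name index, and the pinecone_text package):

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.retrievers import PineconeHybridSearchRetriever
    from pinecone_text.sparse import BM25Encoder

    retriever = PineconeHybridSearchRetriever(
        embeddings=OpenAIEmbeddings(),
        sparse_encoder=BM25Encoder().default(),  # pretrained default weights
        index=index,
        top_k=4,
        alpha=0.5,  # 0 = pure sparse, 1 = pure dense
    )
    retriever.add_texts(["foo", "bar"])
    docs = retriever.get_relevant_documents("foo")
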
View File

@@ -1,4 +1,3 @@
-"""A retriever that uses PubMed API to retrieve documents."""
 from typing import List

 from langchain.callbacks.manager import (
@@ -10,8 +9,8 @@ from langchain.utilities.pupmed import PubMedAPIWrapper
 class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
-    """
-    It is effectively a wrapper for PubMedAPIWrapper.
+    """Retriever for PubMed API.
+
     It wraps load() to get_relevant_documents().
     It uses all PubMedAPIWrapper arguments without any change.
     """

View File

@@ -11,12 +11,20 @@ from langchain.schema import BaseRetriever, Document
 class RemoteLangChainRetriever(BaseRetriever):
+    """Retriever for remote LangChain API."""
+
     url: str
+    """URL of the remote LangChain API."""
     headers: Optional[dict] = None
+    """Headers to use for the request."""
     input_key: str = "message"
+    """Key to use for the input in the request."""
     response_key: str = "response"
+    """Key to use for the response in the request."""
     page_content_key: str = "page_content"
+    """Key to use for the page content in the response."""
     metadata_key: str = "metadata"
+    """Key to use for the metadata in the response."""

     def _get_relevant_documents(
         self, query: str, *, run_manager: CallbackManagerForRetrieverRun

View File

@@ -52,7 +52,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
 class SelfQueryRetriever(BaseRetriever, BaseModel):
-    """Retriever that wraps around a vector store and uses an LLM to generate
+    """Retriever that uses a vector store and an LLM to generate
     the vector store queries."""

     vectorstore: VectorStore

View File

@@ -1,4 +1,3 @@
-"""Logic for converting internal query language to a valid Chroma query."""
 from typing import Dict, Tuple, Union

 from langchain.chains.query_constructor.ir import (
@@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
 class ChromaTranslator(Visitor):
-    """Logic for converting internal query language elements to valid filters."""
+    """Translate internal query language elements to valid filters."""

     allowed_operators = [Operator.AND, Operator.OR]
     """Subset of allowed logical operators."""

View File

@@ -48,7 +48,7 @@ def FUNCTION_COMPOSER(op_name: str) -> Callable:
 class MyScaleTranslator(Visitor):
-    """Logic for converting internal query language elements to valid filters."""
+    """Translate internal query language elements to valid filters."""

     allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
     """Subset of allowed logical operators."""

View File

@@ -1,4 +1,3 @@
-"""Logic for converting internal query language to a valid Pinecone query."""
 from typing import Dict, Tuple, Union

 from langchain.chains.query_constructor.ir import (
@@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
 class PineconeTranslator(Visitor):
-    """Logic for converting internal query language elements to valid filters."""
+    """Translate the internal query language elements to valid filters."""

     allowed_comparators = (
         Comparator.EQ,

View File

@@ -1,4 +1,3 @@
-"""Logic for converting internal query language to a valid Qdrant query."""
 from __future__ import annotations

 from typing import TYPE_CHECKING, Tuple
@@ -17,7 +16,7 @@ if TYPE_CHECKING:
 class QdrantTranslator(Visitor):
-    """Logic for converting internal query language elements to valid filters."""
+    """Translate the internal query language elements to valid filters."""

     allowed_comparators = (
         Comparator.EQ,

View File

@@ -1,4 +1,3 @@
-"""Logic for converting internal query language to a valid Weaviate query."""
 from typing import Dict, Tuple, Union

 from langchain.chains.query_constructor.ir import (
@@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
 class WeaviateTranslator(Visitor):
-    """Logic for converting internal query language elements to valid filters."""
+    """Translate the internal query language elements to valid filters."""

     allowed_operators = [Operator.AND, Operator.OR]
     """Subset of allowed logical operators."""

View File

@@ -1,7 +1,3 @@
-"""SMV Retriever.
-Largely based on
-https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
-
 from __future__ import annotations

 import concurrent.futures
@@ -20,6 +16,7 @@ from langchain.schema import BaseRetriever, Document
 def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
     """
     Create an index of embeddings for a list of contexts.
+
     Args:
         contexts: List of contexts to embed.
         embeddings: Embeddings model to use.
@@ -32,13 +29,22 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
 class SVMRetriever(BaseRetriever):
-    """SVM Retriever."""
+    """SVM Retriever.
+
+    Largely based on
+    https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
+    """

     embeddings: Embeddings
+    """Embeddings model to use."""
     index: Any
+    """Index of embeddings."""
     texts: List[str]
+    """List of texts to index."""
     k: int = 4
+    """Number of results to return."""
     relevancy_threshold: Optional[float] = None
+    """Threshold for relevancy."""

     class Config:

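As with KNNRetriever, a short usage sketch (assumes an OpenAI key for the embeddings):

    from langchain.embeddings import OpenAIEmbeddings
    from langchain.retrievers import SVMRetriever

    retriever = SVMRetriever.from_texts(["foo", "bar", "world"], OpenAIEmbeddings())
    docs = retriever.get_relevant_documents("foo")
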
View File

@@ -1,8 +1,3 @@
-"""TF-IDF Retriever.
-
-Largely based on
-https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
-
 from __future__ import annotations

 from typing import Any, Dict, Iterable, List, Optional
@@ -15,10 +10,20 @@ from langchain.schema import BaseRetriever, Document
 class TFIDFRetriever(BaseRetriever):
+    """TF-IDF Retriever.
+
+    Largely based on
+    https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
+    """
+
     vectorizer: Any
+    """TF-IDF vectorizer."""
     docs: List[Document]
+    """Documents."""
     tfidf_array: Any
+    """TF-IDF array."""
     k: int = 4
+    """Number of documents to return."""

     class Config:
         """Configuration for this pydantic object."""

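A minimal sketch of the fields above in use (assumes scikit-learn is installed):

    from langchain.retrievers import TFIDFRetriever

    retriever = TFIDFRetriever.from_texts(["foo", "bar", "world"], k=2)
    docs = retriever.get_relevant_documents("foo")
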
View File

@@ -1,5 +1,3 @@
-"""Retriever that combines embedding similarity with recency in retrieving values."""
-
 import datetime
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Tuple
@@ -20,7 +18,8 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
 class TimeWeightedVectorStoreRetriever(BaseRetriever):
-    """Retriever combining embedding similarity with recency."""
+    """Retriever that combines embedding similarity with
+    recency in retrieving values."""

     vectorstore: VectorStore
     """The vectorstore to store documents and determine salience."""

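A hedged sketch of the retriever documented above (the FAISS setup and decay_rate value are illustrative; decay_rate and k are assumed constructor fields):

    import faiss
    from langchain.docstore import InMemoryDocstore
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.retrievers import TimeWeightedVectorStoreRetriever
    from langchain.vectorstores import FAISS

    embeddings = OpenAIEmbeddings()
    vectorstore = FAISS(
        embeddings.embed_query, faiss.IndexFlatL2(1536), InMemoryDocstore({}), {}
    )
    # Roughly: score = semantic_similarity + (1.0 - decay_rate) ** hours_passed
    retriever = TimeWeightedVectorStoreRetriever(
        vectorstore=vectorstore, decay_rate=0.01, k=1
    )
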
View File

@@ -1,5 +1,3 @@
-"""Wrapper for retrieving documents from Vespa."""
-
 from __future__ import annotations

 import json
@@ -16,12 +14,16 @@ if TYPE_CHECKING:
 class VespaRetriever(BaseRetriever):
-    """Retriever that uses the Vespa."""
+    """Retriever that uses Vespa."""

     app: Vespa
+    """Vespa application to query."""
     body: Dict
+    """Body of the query."""
     content_field: str
+    """Name of the content field."""
     metadata_fields: Sequence[str]
+    """Names of the metadata fields."""

     def _query(self, body: Dict) -> List[Document]:
         response = self.app.query(body)
@@ -97,6 +99,9 @@ class VespaRetriever(BaseRetriever):
         yql (Optional[str]): Full YQL query to be used. Should not be specified
             if _filter or sources are specified. Defaults to None.
         kwargs (Any): Keyword arguments added to query body.
+
+        Returns:
+            VespaRetriever: Instantiated VespaRetriever.
         """
         try:
             from vespa.application import Vespa

View File

@@ -1,5 +1,3 @@
-"""Wrapper around weaviate vector database."""
-
 from __future__ import annotations

 from typing import Any, Dict, List, Optional, cast
@@ -16,7 +14,7 @@ from langchain.schema import BaseRetriever
 class WeaviateHybridSearchRetriever(BaseRetriever):
-    """Retriever that uses Weaviate's hybrid search to retrieve documents."""
+    """Retriever for the Weaviate's hybrid search."""

     client: Any
     """keyword arguments to pass to the Weaviate client."""

View File

@@ -9,8 +9,8 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
 class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
-    """
-    It is effectively a wrapper for WikipediaAPIWrapper.
+    """Retriever for Wikipedia API.
+
     It wraps load() to get_relevant_documents().
     It uses all WikipediaAPIWrapper arguments without any change.
     """

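As with ArxivRetriever, WikipediaAPIWrapper arguments pass through unchanged; a short sketch (the arguments shown are illustrative):

    from langchain.retrievers import WikipediaRetriever

    retriever = WikipediaRetriever(lang="en", top_k_results=2)
    docs = retriever.get_relevant_documents("large language model")
    print(docs[0].metadata)
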
View File

@@ -15,8 +15,9 @@ if TYPE_CHECKING:
 class ZepRetriever(BaseRetriever):
-    """A Retriever implementation for the Zep long-term memory store. Search your
-    user's long-term chat history with Zep.
+    """Retriever for the Zep long-term memory store.
+
+    Search your user's long-term chat history with Zep.

     Note: You will need to provide the user's `session_id` to use this retriever.
@@ -30,10 +31,11 @@ class ZepRetriever(BaseRetriever):
     """

     zep_client: Any
+    """Zep client."""
     session_id: str
+    """Zep session ID."""
     top_k: Optional[int]
+    """Number of documents to return."""

     @root_validator(pre=True)
     def create_client(cls, values: dict) -> dict:

View File

@@ -1,4 +1,3 @@
-"""Zilliz Retriever"""
 import warnings

 from typing import Any, Dict, List, Optional
@@ -16,16 +15,22 @@ from langchain.vectorstores.zilliz import Zilliz
 class ZillizRetriever(BaseRetriever):
-    """Retriever that uses the Zilliz API."""
+    """Retriever for the Zilliz API."""

     embedding_function: Embeddings
+    """The underlying embedding function from which documents will be retrieved."""
     collection_name: str = "LangChainCollection"
+    """The name of the collection in Zilliz."""
     connection_args: Optional[Dict[str, Any]] = None
+    """The connection arguments for the Zilliz client."""
     consistency_level: str = "Session"
+    """The consistency level for the Zilliz client."""
     search_params: Optional[dict] = None
+    """The search parameters for the Zilliz client."""
     store: Zilliz
+    """The underlying Zilliz store."""
     retriever: BaseRetriever
+    """The underlying retriever."""

     @root_validator(pre=True)
     def create_client(cls, values: dict) -> dict:
@@ -73,8 +78,10 @@ class ZillizRetriever(BaseRetriever):
 def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
-    """
-    Deprecated ZillizRetreiver. Please use ZillizRetriever ('i' before 'e') instead.
+    """Deprecated ZillizRetreiver.
+
+    Please use ZillizRetriever ('i' before 'e') instead.

     Args:
         *args:
         **kwargs: