diff --git a/langchain/retrievers/arxiv.py b/langchain/retrievers/arxiv.py index 89a1370271c..2510aadd79a 100644 --- a/langchain/retrievers/arxiv.py +++ b/langchain/retrievers/arxiv.py @@ -10,7 +10,8 @@ from langchain.utilities.arxiv import ArxivAPIWrapper class ArxivRetriever(BaseRetriever, ArxivAPIWrapper): """ - It is effectively a wrapper for ArxivAPIWrapper. + Retriever for Arxiv. + It wraps load() to get_relevant_documents(). It uses all ArxivAPIWrapper arguments without any change. """ diff --git a/langchain/retrievers/azure_cognitive_search.py b/langchain/retrievers/azure_cognitive_search.py index 02c92b8c19e..214f663e95b 100644 --- a/langchain/retrievers/azure_cognitive_search.py +++ b/langchain/retrievers/azure_cognitive_search.py @@ -1,4 +1,4 @@ -"""Retriever wrapper for Azure Cognitive Search.""" +"""Retriever for the Azure Cognitive Search service.""" from __future__ import annotations @@ -18,7 +18,7 @@ from langchain.utils import get_from_dict_or_env class AzureCognitiveSearchRetriever(BaseRetriever): - """Wrapper around Azure Cognitive Search.""" + """Retriever for the Azure Cognitive Search service.""" service_name: str = "" """Name of Azure Cognitive Search service""" diff --git a/langchain/retrievers/bm25.py b/langchain/retrievers/bm25.py index 4487654140e..735ca9913a4 100644 --- a/langchain/retrievers/bm25.py +++ b/langchain/retrievers/bm25.py @@ -19,10 +19,16 @@ def default_preprocessing_func(text: str) -> List[str]: class BM25Retriever(BaseRetriever): + """BM25 Retriever without Elasticsearch.""" + vectorizer: Any + """BM25 vectorizer.""" docs: List[Document] + """List of documents.""" k: int = 4 + """Number of documents to return.""" preprocess_func: Callable[[str], List[str]] = default_preprocessing_func + """Preprocessing function to use on the text before BM25 vectorization.""" class Config: """Configuration for this pydantic object.""" @@ -38,6 +44,18 @@ class BM25Retriever(BaseRetriever): preprocess_func: Callable[[str], List[str]] = 
default_preprocessing_func, **kwargs: Any, ) -> BM25Retriever: + """ + Create a BM25Retriever from a list of texts. + Args: + texts: A list of texts to vectorize. + metadatas: A list of metadata dicts to associate with each text. + bm25_params: Parameters to pass to the BM25 vectorizer. + preprocess_func: A function to preprocess each text before vectorization. + **kwargs: Any other arguments to pass to the retriever. + + Returns: + A BM25Retriever instance. + """ try: from rank_bm25 import BM25Okapi except ImportError: @@ -64,6 +82,17 @@ class BM25Retriever(BaseRetriever): preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, **kwargs: Any, ) -> BM25Retriever: + """ + Create a BM25Retriever from a list of Documents. + Args: + documents: A list of Documents to vectorize. + bm25_params: Parameters to pass to the BM25 vectorizer. + preprocess_func: A function to preprocess each text before vectorization. + **kwargs: Any other arguments to pass to the retriever. + + Returns: + A BM25Retriever instance. 
+ """ texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents)) return cls.from_texts( texts=texts, diff --git a/langchain/retrievers/chaindesk.py b/langchain/retrievers/chaindesk.py index f2bf654de07..b68f31658e4 100644 --- a/langchain/retrievers/chaindesk.py +++ b/langchain/retrievers/chaindesk.py @@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document class ChaindeskRetriever(BaseRetriever): - """Retriever that uses the Chaindesk API.""" + """Retriever for the Chaindesk API.""" datastore_url: str top_k: Optional[int] diff --git a/langchain/retrievers/chatgpt_plugin_retriever.py b/langchain/retrievers/chatgpt_plugin_retriever.py index 4eba2a44591..5f2404f88ca 100644 --- a/langchain/retrievers/chatgpt_plugin_retriever.py +++ b/langchain/retrievers/chatgpt_plugin_retriever.py @@ -13,16 +13,24 @@ from langchain.schema import BaseRetriever, Document class ChatGPTPluginRetriever(BaseRetriever): + """Retrieves documents from a ChatGPT plugin.""" + url: str + """URL of the ChatGPT plugin.""" bearer_token: str + """Bearer token for the ChatGPT plugin.""" top_k: int = 3 + """Number of documents to return.""" filter: Optional[dict] = None + """Filter to apply to the results.""" aiosession: Optional[aiohttp.ClientSession] = None + """Aiohttp session to use for requests.""" class Config: """Configuration for this pydantic object.""" arbitrary_types_allowed = True + """Allow arbitrary types.""" def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun diff --git a/langchain/retrievers/contextual_compression.py b/langchain/retrievers/contextual_compression.py index d3810893e58..0a5654b052b 100644 --- a/langchain/retrievers/contextual_compression.py +++ b/langchain/retrievers/contextual_compression.py @@ -1,5 +1,3 @@ -"""Retriever that wraps a base retriever and filters the results.""" - from typing import Any, List from langchain.callbacks.manager import ( diff --git a/langchain/retrievers/databerry.py 
b/langchain/retrievers/databerry.py index aa03d1cf463..d46144ac6f1 100644 --- a/langchain/retrievers/databerry.py +++ b/langchain/retrievers/databerry.py @@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document class DataberryRetriever(BaseRetriever): - """Retriever that uses the Databerry API.""" + """Retriever for the Databerry API.""" datastore_url: str top_k: Optional[int] diff --git a/langchain/retrievers/docarray.py b/langchain/retrievers/docarray.py index e50a654ad3f..0ba247bf0c5 100644 --- a/langchain/retrievers/docarray.py +++ b/langchain/retrievers/docarray.py @@ -21,7 +21,7 @@ class SearchType(str, Enum): class DocArrayRetriever(BaseRetriever): """ - Retriever class for DocArray Document Indices. + Retriever for DocArray Document Indices. Currently, supports 5 backends: InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex, diff --git a/langchain/retrievers/document_compressors/chain_extract.py b/langchain/retrievers/document_compressors/chain_extract.py index a50703f2cc2..872124dd60e 100644 --- a/langchain/retrievers/document_compressors/chain_extract.py +++ b/langchain/retrievers/document_compressors/chain_extract.py @@ -42,6 +42,9 @@ def _get_default_chain_prompt() -> PromptTemplate: class LLMChainExtractor(BaseDocumentCompressor): + """DocumentCompressor that uses an LLM chain to extract + the relevant parts of documents.""" + llm_chain: LLMChain """LLM wrapper to use for compressing documents.""" diff --git a/langchain/retrievers/document_compressors/chain_filter.py b/langchain/retrievers/document_compressors/chain_filter.py index 5b4aae949dc..68d6123a248 100644 --- a/langchain/retrievers/document_compressors/chain_filter.py +++ b/langchain/retrievers/document_compressors/chain_filter.py @@ -68,6 +68,16 @@ class LLMChainFilter(BaseDocumentCompressor): prompt: Optional[BasePromptTemplate] = None, **kwargs: Any ) -> "LLMChainFilter": + """Create a LLMChainFilter from a language model. 
+ + Args: + llm: The language model to use for filtering. + prompt: The prompt to use for the filter. + **kwargs: Additional arguments to pass to the constructor. + + Returns: + A LLMChainFilter that uses the given language model. + """ _prompt = prompt if prompt is not None else _get_default_chain_prompt() llm_chain = LLMChain(llm=llm, prompt=_prompt) return cls(llm_chain=llm_chain, **kwargs) diff --git a/langchain/retrievers/document_compressors/cohere_rerank.py b/langchain/retrievers/document_compressors/cohere_rerank.py index dd63b842427..722cc6d33c3 100644 --- a/langchain/retrievers/document_compressors/cohere_rerank.py +++ b/langchain/retrievers/document_compressors/cohere_rerank.py @@ -21,9 +21,14 @@ else: class CohereRerank(BaseDocumentCompressor): + """DocumentCompressor that uses Cohere's rerank API to compress documents.""" + client: Client + """Cohere client to use for compressing documents.""" top_n: int = 3 + """Number of documents to return.""" model: str = "rerank-english-v2.0" + """Model to use for reranking.""" class Config: """Configuration for this pydantic object.""" @@ -54,6 +59,17 @@ class CohereRerank(BaseDocumentCompressor): query: str, callbacks: Optional[Callbacks] = None, ) -> Sequence[Document]: + """ + Compress documents using Cohere's rerank API. + + Args: + documents: A sequence of documents to compress. + query: The query to use for compressing the documents. + callbacks: Callbacks to run during the compression process. + + Returns: + A sequence of compressed documents. 
+ """ if len(documents) == 0: # to avoid empty api call return [] doc_list = list(documents) diff --git a/langchain/retrievers/document_compressors/embeddings_filter.py b/langchain/retrievers/document_compressors/embeddings_filter.py index 692a34bcc78..589ff465696 100644 --- a/langchain/retrievers/document_compressors/embeddings_filter.py +++ b/langchain/retrievers/document_compressors/embeddings_filter.py @@ -1,4 +1,3 @@ -"""Document compressor that uses embeddings to drop documents unrelated to the query.""" from typing import Callable, Dict, Optional, Sequence import numpy as np @@ -18,6 +17,9 @@ from langchain.schema import Document class EmbeddingsFilter(BaseDocumentCompressor): + """Document compressor that uses embeddings to drop documents + unrelated to the query.""" + embeddings: Embeddings """Embeddings to use for embedding document contents and queries.""" similarity_fn: Callable = cosine_similarity diff --git a/langchain/retrievers/elastic_search_bm25.py b/langchain/retrievers/elastic_search_bm25.py index e1e09f55c1c..3a76b36ef0f 100644 --- a/langchain/retrievers/elastic_search_bm25.py +++ b/langchain/retrievers/elastic_search_bm25.py @@ -14,8 +14,7 @@ from langchain.schema import BaseRetriever class ElasticSearchBM25Retriever(BaseRetriever): - """Wrapper around Elasticsearch using BM25 as a retrieval method. - + """Retriever for Elasticsearch using BM25 as a retrieval method. To connect to an Elasticsearch instance that requires login credentials, including Elastic Cloud, use the Elasticsearch URL format @@ -41,12 +40,26 @@ class ElasticSearchBM25Retriever(BaseRetriever): """ client: Any + """Elasticsearch client.""" index_name: str + """Name of the index to use in Elasticsearch.""" @classmethod def create( cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75 ) -> ElasticSearchBM25Retriever: + """ + Create an ElasticSearchBM25Retriever from an Elasticsearch URL and index name. 
+ + Args: + elasticsearch_url: URL of the Elasticsearch instance to connect to. + index_name: Name of the index to use in Elasticsearch. + k1: BM25 parameter k1. + b: BM25 parameter b. + + Returns: + An ElasticSearchBM25Retriever instance. + """ from elasticsearch import Elasticsearch # Create an Elasticsearch client instance diff --git a/langchain/retrievers/kendra.py b/langchain/retrievers/kendra.py index 94b72759e4d..2ceeb1d0fcd 100644 --- a/langchain/retrievers/kendra.py +++ b/langchain/retrievers/kendra.py @@ -51,24 +51,38 @@ class Highlight(BaseModel, extra=Extra.allow): class TextWithHighLights(BaseModel, extra=Extra.allow): + """Text with highlights.""" + Text: str + """The text.""" Highlights: Optional[Any] + """The highlights.""" class AdditionalResultAttributeValue(BaseModel, extra=Extra.allow): + """The value of an additional result attribute.""" + TextWithHighlightsValue: TextWithHighLights + """The text with highlights value.""" class AdditionalResultAttribute(BaseModel, extra=Extra.allow): + """An additional result attribute.""" + Key: str + """The key of the attribute.""" ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"] + """The type of the value.""" Value: AdditionalResultAttributeValue + """The value of the attribute.""" def get_value_text(self) -> str: return self.Value.TextWithHighlightsValue.Text class QueryResultItem(BaseModel, extra=Extra.allow): + """A query result item.""" + DocumentId: str DocumentTitle: TextWithHighLights DocumentURI: Optional[str] @@ -111,9 +125,19 @@ class QueryResultItem(BaseModel, extra=Extra.allow): class QueryResult(BaseModel, extra=Extra.allow): + """A query result.""" + ResultItems: List[QueryResultItem] def get_top_k_docs(self, top_n: int) -> List[Document]: + """Gets the top k documents. + + Args: + top_n: The number of documents to return. + + Returns: + The top k documents. 
+ """ items_len = len(self.ResultItems) count = items_len if items_len < top_n else top_n docs = [self.ResultItems[i].to_doc() for i in range(0, count)] @@ -122,24 +146,42 @@ class QueryResult(BaseModel, extra=Extra.allow): class DocumentAttributeValue(BaseModel, extra=Extra.allow): + """The value of a document attribute.""" + DateValue: Optional[str] + """The date value.""" LongValue: Optional[int] + """The long value.""" StringListValue: Optional[List[str]] + """The string list value.""" StringValue: Optional[str] + """The string value.""" class DocumentAttribute(BaseModel, extra=Extra.allow): + """A document attribute.""" + Key: str + """The key of the attribute.""" Value: DocumentAttributeValue + """The value of the attribute.""" class RetrieveResultItem(BaseModel, extra=Extra.allow): + """A retrieve result item.""" + Content: Optional[str] + """The content of the item.""" DocumentAttributes: Optional[List[DocumentAttribute]] = [] + """The document attributes.""" DocumentId: Optional[str] + """The document ID.""" DocumentTitle: Optional[str] + """The document title.""" DocumentURI: Optional[str] + """The document URI.""" Id: Optional[str] + """The ID of the item.""" def get_excerpt(self) -> str: if not self.Content: @@ -156,8 +198,12 @@ class RetrieveResultItem(BaseModel, extra=Extra.allow): class RetrieveResult(BaseModel, extra=Extra.allow): + """A retrieve result.""" + QueryId: str + """The ID of the query.""" ResultItems: List[RetrieveResultItem] + """The result items.""" def get_top_k_docs(self, top_n: int) -> List[Document]: items_len = len(self.ResultItems) @@ -168,7 +214,7 @@ class RetrieveResult(BaseModel, extra=Extra.allow): class AmazonKendraRetriever(BaseRetriever): - """Retriever class to query documents from Amazon Kendra Index. + """Retriever for the Amazon Kendra Index. 
Args: index_id: Kendra index id diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py index 945909d10b2..51a3effb221 100644 --- a/langchain/retrievers/knn.py +++ b/langchain/retrievers/knn.py @@ -36,10 +36,15 @@ class KNNRetriever(BaseRetriever): """KNN Retriever.""" embeddings: Embeddings + """Embeddings model to use.""" index: Any + """Index of embeddings.""" texts: List[str] + """List of texts to index.""" k: int = 4 + """Number of results to return.""" relevancy_threshold: Optional[float] = None + """Threshold for relevancy.""" class Config: diff --git a/langchain/retrievers/llama_index.py b/langchain/retrievers/llama_index.py index 8cce86418db..e393d121aa7 100644 --- a/langchain/retrievers/llama_index.py +++ b/langchain/retrievers/llama_index.py @@ -10,10 +10,13 @@ from langchain.schema import BaseRetriever, Document class LlamaIndexRetriever(BaseRetriever): - """Question-answering with sources over an LlamaIndex data structure.""" + """Retriever for question-answering with sources over + a LlamaIndex data structure.""" index: Any + """LlamaIndex index to query.""" query_kwargs: Dict = Field(default_factory=dict) + """Keyword arguments to pass to the query method.""" def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun @@ -46,10 +49,13 @@ class LlamaIndexRetriever(BaseRetriever): class LlamaIndexGraphRetriever(BaseRetriever): - """Question-answering with sources over an LlamaIndex graph data structure.""" + """Retriever for question-answering with sources over a LlamaIndex + graph data structure.""" graph: Any + """LlamaIndex graph to query.""" query_configs: List[Dict] = Field(default_factory=list) + """List of query configs to pass to the query method.""" def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun diff --git a/langchain/retrievers/merger_retriever.py b/langchain/retrievers/merger_retriever.py index b1f8b329929..92af2c7ca68 100644 --- 
a/langchain/retrievers/merger_retriever.py +++ b/langchain/retrievers/merger_retriever.py @@ -8,14 +8,10 @@ from langchain.schema import BaseRetriever, Document class MergerRetriever(BaseRetriever): - """ - This class merges the results of multiple retrievers. - - Args: - retrievers: A list of retrievers to merge. - """ + """Retriever that merges the results of multiple retrievers.""" retrievers: List[BaseRetriever] + """A list of retrievers to merge.""" def _get_relevant_documents( self, diff --git a/langchain/retrievers/metal.py b/langchain/retrievers/metal.py index 80dd12cc31f..5271a897b54 100644 --- a/langchain/retrievers/metal.py +++ b/langchain/retrievers/metal.py @@ -13,8 +13,9 @@ class MetalRetriever(BaseRetriever): """Retriever that uses the Metal API.""" client: Any - + """The Metal client to use.""" params: Optional[dict] = None + """The parameters to pass to the Metal client.""" @root_validator(pre=True) def validate_client(cls, values: dict) -> dict: diff --git a/langchain/retrievers/multi_query.py b/langchain/retrievers/multi_query.py index 52fe6fb80f7..ca7d6eb3eb5 100644 --- a/langchain/retrievers/multi_query.py +++ b/langchain/retrievers/multi_query.py @@ -17,10 +17,15 @@ logger = logging.getLogger(__name__) class LineList(BaseModel): + """List of lines.""" + lines: List[str] = Field(description="Lines of text") + """List of lines.""" class LineListOutputParser(PydanticOutputParser): + """Output parser for a list of lines.""" + def __init__(self) -> None: super().__init__(pydantic_object=LineList) diff --git a/langchain/retrievers/pinecone_hybrid_search.py b/langchain/retrievers/pinecone_hybrid_search.py index 6e8541937dc..a6c998719b9 100644 --- a/langchain/retrievers/pinecone_hybrid_search.py +++ b/langchain/retrievers/pinecone_hybrid_search.py @@ -99,12 +99,19 @@ def create_index( class PineconeHybridSearchRetriever(BaseRetriever): + """Pinecone Hybrid Search Retriever.""" + embeddings: Embeddings + """Embeddings model to use.""" """description""" 
sparse_encoder: Any + """Sparse encoder to use.""" index: Any + """Pinecone index to use.""" top_k: int = 4 + """Number of documents to return.""" alpha: float = 0.5 + """Alpha value for hybrid search.""" class Config: """Configuration for this pydantic object.""" diff --git a/langchain/retrievers/pubmed.py b/langchain/retrievers/pubmed.py index 573a9f2c100..d49d581800a 100644 --- a/langchain/retrievers/pubmed.py +++ b/langchain/retrievers/pubmed.py @@ -1,4 +1,3 @@ -"""A retriever that uses PubMed API to retrieve documents.""" from typing import List from langchain.callbacks.manager import ( @@ -10,8 +9,8 @@ from langchain.utilities.pupmed import PubMedAPIWrapper class PubMedRetriever(BaseRetriever, PubMedAPIWrapper): - """ - It is effectively a wrapper for PubMedAPIWrapper. + """Retriever for PubMed API. + It wraps load() to get_relevant_documents(). It uses all PubMedAPIWrapper arguments without any change. """ diff --git a/langchain/retrievers/remote_retriever.py b/langchain/retrievers/remote_retriever.py index f0f7ba4dc89..ae17401974c 100644 --- a/langchain/retrievers/remote_retriever.py +++ b/langchain/retrievers/remote_retriever.py @@ -11,12 +11,20 @@ from langchain.schema import BaseRetriever, Document class RemoteLangChainRetriever(BaseRetriever): + """Retriever for remote LangChain API.""" + url: str + """URL of the remote LangChain API.""" headers: Optional[dict] = None + """Headers to use for the request.""" input_key: str = "message" + """Key to use for the input in the request.""" response_key: str = "response" + """Key to use for the response in the request.""" page_content_key: str = "page_content" + """Key to use for the page content in the response.""" metadata_key: str = "metadata" + """Key to use for the metadata in the response.""" def _get_relevant_documents( self, query: str, *, run_manager: CallbackManagerForRetrieverRun diff --git a/langchain/retrievers/self_query/base.py b/langchain/retrievers/self_query/base.py index 
60e63087c60..653c314dcfc 100644 --- a/langchain/retrievers/self_query/base.py +++ b/langchain/retrievers/self_query/base.py @@ -52,7 +52,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor: class SelfQueryRetriever(BaseRetriever, BaseModel): - """Retriever that wraps around a vector store and uses an LLM to generate + """Retriever that uses a vector store and an LLM to generate the vector store queries.""" vectorstore: VectorStore diff --git a/langchain/retrievers/self_query/chroma.py b/langchain/retrievers/self_query/chroma.py index cbb707d1ac3..64c3fcbc45d 100644 --- a/langchain/retrievers/self_query/chroma.py +++ b/langchain/retrievers/self_query/chroma.py @@ -1,4 +1,3 @@ -"""Logic for converting internal query language to a valid Chroma query.""" from typing import Dict, Tuple, Union from langchain.chains.query_constructor.ir import ( @@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import ( class ChromaTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" + """Translate internal query language elements to valid filters.""" allowed_operators = [Operator.AND, Operator.OR] """Subset of allowed logical operators.""" diff --git a/langchain/retrievers/self_query/myscale.py b/langchain/retrievers/self_query/myscale.py index cb6d147f23c..e50af7a1293 100644 --- a/langchain/retrievers/self_query/myscale.py +++ b/langchain/retrievers/self_query/myscale.py @@ -48,7 +48,7 @@ def FUNCTION_COMPOSER(op_name: str) -> Callable: class MyScaleTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" + """Translate internal query language elements to valid filters.""" allowed_operators = [Operator.AND, Operator.OR, Operator.NOT] """Subset of allowed logical operators.""" diff --git a/langchain/retrievers/self_query/pinecone.py b/langchain/retrievers/self_query/pinecone.py index e7664ad26d7..733065937f7 100644 --- a/langchain/retrievers/self_query/pinecone.py +++ 
b/langchain/retrievers/self_query/pinecone.py @@ -1,4 +1,3 @@ -"""Logic for converting internal query language to a valid Pinecone query.""" from typing import Dict, Tuple, Union from langchain.chains.query_constructor.ir import ( @@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import ( class PineconeTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" + """Translate the internal query language elements to valid filters.""" allowed_comparators = ( Comparator.EQ, diff --git a/langchain/retrievers/self_query/qdrant.py b/langchain/retrievers/self_query/qdrant.py index 84e964bdf90..e421eef023e 100644 --- a/langchain/retrievers/self_query/qdrant.py +++ b/langchain/retrievers/self_query/qdrant.py @@ -1,4 +1,3 @@ -"""Logic for converting internal query language to a valid Qdrant query.""" from __future__ import annotations from typing import TYPE_CHECKING, Tuple @@ -17,7 +16,7 @@ if TYPE_CHECKING: class QdrantTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" + """Translate the internal query language elements to valid filters.""" allowed_comparators = ( Comparator.EQ, diff --git a/langchain/retrievers/self_query/weaviate.py b/langchain/retrievers/self_query/weaviate.py index af6a8acc3fe..cc4727c0951 100644 --- a/langchain/retrievers/self_query/weaviate.py +++ b/langchain/retrievers/self_query/weaviate.py @@ -1,4 +1,3 @@ -"""Logic for converting internal query language to a valid Weaviate query.""" from typing import Dict, Tuple, Union from langchain.chains.query_constructor.ir import ( @@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import ( class WeaviateTranslator(Visitor): - """Logic for converting internal query language elements to valid filters.""" + """Translate the internal query language elements to valid filters.""" allowed_operators = [Operator.AND, Operator.OR] """Subset of allowed logical operators.""" diff --git 
a/langchain/retrievers/svm.py b/langchain/retrievers/svm.py index f2e6a141dd9..96e34f160f5 100644 --- a/langchain/retrievers/svm.py +++ b/langchain/retrievers/svm.py @@ -1,7 +1,3 @@ -"""SMV Retriever. -Largely based on -https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb""" - from __future__ import annotations import concurrent.futures @@ -20,6 +16,7 @@ from langchain.schema import BaseRetriever, Document def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: """ Create an index of embeddings for a list of contexts. + Args: contexts: List of contexts to embed. embeddings: Embeddings model to use. @@ -32,13 +29,22 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: class SVMRetriever(BaseRetriever): - """SVM Retriever.""" + """SVM Retriever. + + Largely based on + https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb + """ embeddings: Embeddings + """Embeddings model to use.""" index: Any + """Index of embeddings.""" texts: List[str] + """List of texts to index.""" k: int = 4 + """Number of results to return.""" relevancy_threshold: Optional[float] = None + """Threshold for relevancy.""" class Config: diff --git a/langchain/retrievers/tfidf.py b/langchain/retrievers/tfidf.py index 5517de547ff..6be34df2606 100644 --- a/langchain/retrievers/tfidf.py +++ b/langchain/retrievers/tfidf.py @@ -1,8 +1,3 @@ -"""TF-IDF Retriever. - -Largely based on -https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb""" - from __future__ import annotations from typing import Any, Dict, Iterable, List, Optional @@ -15,10 +10,20 @@ from langchain.schema import BaseRetriever, Document class TFIDFRetriever(BaseRetriever): + """TF-IDF Retriever. 
+ + Largely based on + https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb + """ + vectorizer: Any + """TF-IDF vectorizer.""" docs: List[Document] + """Documents.""" tfidf_array: Any + """TF-IDF array.""" k: int = 4 + """Number of documents to return.""" class Config: """Configuration for this pydantic object.""" diff --git a/langchain/retrievers/time_weighted_retriever.py b/langchain/retrievers/time_weighted_retriever.py index 64e641a6a50..dd767e53b7e 100644 --- a/langchain/retrievers/time_weighted_retriever.py +++ b/langchain/retrievers/time_weighted_retriever.py @@ -1,5 +1,3 @@ -"""Retriever that combines embedding similarity with recency in retrieving values.""" - import datetime from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple @@ -20,7 +18,8 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f class TimeWeightedVectorStoreRetriever(BaseRetriever): - """Retriever combining embedding similarity with recency.""" + """Retriever that combines embedding similarity with + recency in retrieving values.""" vectorstore: VectorStore """The vectorstore to store documents and determine salience.""" diff --git a/langchain/retrievers/vespa_retriever.py b/langchain/retrievers/vespa_retriever.py index 29bd420ced0..50e0396a5eb 100644 --- a/langchain/retrievers/vespa_retriever.py +++ b/langchain/retrievers/vespa_retriever.py @@ -1,5 +1,3 @@ -"""Wrapper for retrieving documents from Vespa.""" - from __future__ import annotations import json @@ -16,12 +14,16 @@ if TYPE_CHECKING: class VespaRetriever(BaseRetriever): - """Retriever that uses the Vespa.""" + """Retriever that uses Vespa.""" app: Vespa + """Vespa application to query.""" body: Dict + """Body of the query.""" content_field: str + """Name of the content field.""" metadata_fields: Sequence[str] + """Names of the metadata fields.""" def _query(self, body: Dict) -> List[Document]: response = self.app.query(body) @@ 
-97,6 +99,9 @@ class VespaRetriever(BaseRetriever): yql (Optional[str]): Full YQL query to be used. Should not be specified if _filter or sources are specified. Defaults to None. kwargs (Any): Keyword arguments added to query body. + + Returns: + VespaRetriever: Instantiated VespaRetriever. """ try: from vespa.application import Vespa diff --git a/langchain/retrievers/weaviate_hybrid_search.py b/langchain/retrievers/weaviate_hybrid_search.py index fe2601c835d..d6d00d5326c 100644 --- a/langchain/retrievers/weaviate_hybrid_search.py +++ b/langchain/retrievers/weaviate_hybrid_search.py @@ -1,5 +1,3 @@ -"""Wrapper around weaviate vector database.""" - from __future__ import annotations from typing import Any, Dict, List, Optional, cast @@ -16,7 +14,7 @@ from langchain.schema import BaseRetriever class WeaviateHybridSearchRetriever(BaseRetriever): - """Retriever that uses Weaviate's hybrid search to retrieve documents.""" + """Retriever for Weaviate's hybrid search.""" client: Any """keyword arguments to pass to the Weaviate client.""" diff --git a/langchain/retrievers/wikipedia.py b/langchain/retrievers/wikipedia.py index fe43099cf98..d47775878ac 100644 --- a/langchain/retrievers/wikipedia.py +++ b/langchain/retrievers/wikipedia.py @@ -9,8 +9,8 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper): - """ - It is effectively a wrapper for WikipediaAPIWrapper. + """Retriever for Wikipedia API. + It wraps load() to get_relevant_documents(). It uses all WikipediaAPIWrapper arguments without any change. """ diff --git a/langchain/retrievers/zep.py b/langchain/retrievers/zep.py index 64333ff8dc8..7d1f18cda1a 100644 --- a/langchain/retrievers/zep.py +++ b/langchain/retrievers/zep.py @@ -15,8 +15,9 @@ if TYPE_CHECKING: class ZepRetriever(BaseRetriever): - """A Retriever implementation for the Zep long-term memory store. Search your - user's long-term chat history with Zep. 
+ """Retriever for the Zep long-term memory store. + + Search your user's long-term chat history with Zep. Note: You will need to provide the user's `session_id` to use this retriever. @@ -30,10 +31,11 @@ class ZepRetriever(BaseRetriever): """ zep_client: Any - + """Zep client.""" session_id: str - + """Zep session ID.""" top_k: Optional[int] + """Number of documents to return.""" @root_validator(pre=True) def create_client(cls, values: dict) -> dict: diff --git a/langchain/retrievers/zilliz.py b/langchain/retrievers/zilliz.py index 8ff463d283d..e40366d161e 100644 --- a/langchain/retrievers/zilliz.py +++ b/langchain/retrievers/zilliz.py @@ -1,4 +1,3 @@ -"""Zilliz Retriever""" import warnings from typing import Any, Dict, List, Optional @@ -16,16 +15,22 @@ from langchain.vectorstores.zilliz import Zilliz class ZillizRetriever(BaseRetriever): - """Retriever that uses the Zilliz API.""" + """Retriever for the Zilliz API.""" embedding_function: Embeddings + """The underlying embedding function from which documents will be retrieved.""" collection_name: str = "LangChainCollection" + """The name of the collection in Zilliz.""" connection_args: Optional[Dict[str, Any]] = None + """The connection arguments for the Zilliz client.""" consistency_level: str = "Session" + """The consistency level for the Zilliz client.""" search_params: Optional[dict] = None - + """The search parameters for the Zilliz client.""" store: Zilliz + """The underlying Zilliz store.""" retriever: BaseRetriever + """The underlying retriever.""" @root_validator(pre=True) def create_client(cls, values: dict) -> dict: @@ -73,8 +78,10 @@ class ZillizRetriever(BaseRetriever): def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever: - """ - Deprecated ZillizRetreiver. Please use ZillizRetriever ('i' before 'e') instead. + """Deprecated ZillizRetreiver. + + Please use ZillizRetriever ('i' before 'e') instead. + Args: *args: **kwargs: