mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 22:29:51 +00:00
docstrings retrievers
(#7858)
Added/updated docstrings `retrievers` @baskaryan
This commit is contained in:
parent
5b4d53e8ef
commit
74b701f42b
@ -10,7 +10,8 @@ from langchain.utilities.arxiv import ArxivAPIWrapper
|
||||
|
||||
class ArxivRetriever(BaseRetriever, ArxivAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for ArxivAPIWrapper.
|
||||
Retriever for Arxiv.
|
||||
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all ArxivAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
@ -1,4 +1,4 @@
|
||||
"""Retriever wrapper for Azure Cognitive Search."""
|
||||
"""Retriever for the Azure Cognitive Search service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@ -18,7 +18,7 @@ from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class AzureCognitiveSearchRetriever(BaseRetriever):
|
||||
"""Wrapper around Azure Cognitive Search."""
|
||||
"""Retriever for the Azure Cognitive Search service."""
|
||||
|
||||
service_name: str = ""
|
||||
"""Name of Azure Cognitive Search service"""
|
||||
|
@ -19,10 +19,16 @@ def default_preprocessing_func(text: str) -> List[str]:
|
||||
|
||||
|
||||
class BM25Retriever(BaseRetriever):
|
||||
"""BM25 Retriever without elastic search."""
|
||||
|
||||
vectorizer: Any
|
||||
""" BM25 vectorizer."""
|
||||
docs: List[Document]
|
||||
""" List of documents."""
|
||||
k: int = 4
|
||||
""" Number of documents to return."""
|
||||
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func
|
||||
""" Preprocessing function to use on the text before BM25 vectorization."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -38,6 +44,18 @@ class BM25Retriever(BaseRetriever):
|
||||
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
|
||||
**kwargs: Any,
|
||||
) -> BM25Retriever:
|
||||
"""
|
||||
Create a BM25Retriever from a list of texts.
|
||||
Args:
|
||||
texts: A list of texts to vectorize.
|
||||
metadatas: A list of metadata dicts to associate with each text.
|
||||
bm25_params: Parameters to pass to the BM25 vectorizer.
|
||||
preprocess_func: A function to preprocess each text before vectorization.
|
||||
**kwargs: Any other arguments to pass to the retriever.
|
||||
|
||||
Returns:
|
||||
A BM25Retriever instance.
|
||||
"""
|
||||
try:
|
||||
from rank_bm25 import BM25Okapi
|
||||
except ImportError:
|
||||
@ -64,6 +82,17 @@ class BM25Retriever(BaseRetriever):
|
||||
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
|
||||
**kwargs: Any,
|
||||
) -> BM25Retriever:
|
||||
"""
|
||||
Create a BM25Retriever from a list of Documents.
|
||||
Args:
|
||||
documents: A list of Documents to vectorize.
|
||||
bm25_params: Parameters to pass to the BM25 vectorizer.
|
||||
preprocess_func: A function to preprocess each text before vectorization.
|
||||
**kwargs: Any other arguments to pass to the retriever.
|
||||
|
||||
Returns:
|
||||
A BM25Retriever instance.
|
||||
"""
|
||||
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
|
||||
return cls.from_texts(
|
||||
texts=texts,
|
||||
|
@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class ChaindeskRetriever(BaseRetriever):
|
||||
"""Retriever that uses the Chaindesk API."""
|
||||
"""Retriever for the Chaindesk API."""
|
||||
|
||||
datastore_url: str
|
||||
top_k: Optional[int]
|
||||
|
@ -13,16 +13,24 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class ChatGPTPluginRetriever(BaseRetriever):
|
||||
"""Retrieves documents from a ChatGPT plugin."""
|
||||
|
||||
url: str
|
||||
"""URL of the ChatGPT plugin."""
|
||||
bearer_token: str
|
||||
"""Bearer token for the ChatGPT plugin."""
|
||||
top_k: int = 3
|
||||
"""Number of documents to return."""
|
||||
filter: Optional[dict] = None
|
||||
"""Filter to apply to the results."""
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
"""Aiohttp session to use for requests."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
arbitrary_types_allowed = True
|
||||
"""Allow arbitrary types."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Retriever that wraps a base retriever and filters the results."""
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
|
@ -11,7 +11,7 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class DataberryRetriever(BaseRetriever):
|
||||
"""Retriever that uses the Databerry API."""
|
||||
"""Retriever for the Databerry API."""
|
||||
|
||||
datastore_url: str
|
||||
top_k: Optional[int]
|
||||
|
@ -21,7 +21,7 @@ class SearchType(str, Enum):
|
||||
|
||||
class DocArrayRetriever(BaseRetriever):
|
||||
"""
|
||||
Retriever class for DocArray Document Indices.
|
||||
Retriever for DocArray Document Indices.
|
||||
|
||||
Currently, supports 5 backends:
|
||||
InMemoryExactNNIndex, HnswDocumentIndex, QdrantDocumentIndex,
|
||||
|
@ -42,6 +42,9 @@ def _get_default_chain_prompt() -> PromptTemplate:
|
||||
|
||||
|
||||
class LLMChainExtractor(BaseDocumentCompressor):
|
||||
"""DocumentCompressor that uses an LLM chain to extract
|
||||
the relevant parts of documents."""
|
||||
|
||||
llm_chain: LLMChain
|
||||
"""LLM wrapper to use for compressing documents."""
|
||||
|
||||
|
@ -68,6 +68,16 @@ class LLMChainFilter(BaseDocumentCompressor):
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
**kwargs: Any
|
||||
) -> "LLMChainFilter":
|
||||
"""Create a LLMChainFilter from a language model.
|
||||
|
||||
Args:
|
||||
llm: The language model to use for filtering.
|
||||
prompt: The prompt to use for the filter.
|
||||
**kwargs: Additional arguments to pass to the constructor.
|
||||
|
||||
Returns:
|
||||
A LLMChainFilter that uses the given language model.
|
||||
"""
|
||||
_prompt = prompt if prompt is not None else _get_default_chain_prompt()
|
||||
llm_chain = LLMChain(llm=llm, prompt=_prompt)
|
||||
return cls(llm_chain=llm_chain, **kwargs)
|
||||
|
@ -21,9 +21,14 @@ else:
|
||||
|
||||
|
||||
class CohereRerank(BaseDocumentCompressor):
|
||||
"""DocumentCompressor that uses Cohere's rerank API to compress documents."""
|
||||
|
||||
client: Client
|
||||
"""Cohere client to use for compressing documents."""
|
||||
top_n: int = 3
|
||||
"""Number of documents to return."""
|
||||
model: str = "rerank-english-v2.0"
|
||||
"""Model to use for reranking."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
@ -54,6 +59,17 @@ class CohereRerank(BaseDocumentCompressor):
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
) -> Sequence[Document]:
|
||||
"""
|
||||
Compress documents using Cohere's rerank API.
|
||||
|
||||
Args:
|
||||
documents: A sequence of documents to compress.
|
||||
query: The query to use for compressing the documents.
|
||||
callbacks: Callbacks to run during the compression process.
|
||||
|
||||
Returns:
|
||||
A sequence of compressed documents.
|
||||
"""
|
||||
if len(documents) == 0: # to avoid empty api call
|
||||
return []
|
||||
doc_list = list(documents)
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Document compressor that uses embeddings to drop documents unrelated to the query."""
|
||||
from typing import Callable, Dict, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
@ -18,6 +17,9 @@ from langchain.schema import Document
|
||||
|
||||
|
||||
class EmbeddingsFilter(BaseDocumentCompressor):
|
||||
"""Document compressor that uses embeddings to drop documents
|
||||
unrelated to the query."""
|
||||
|
||||
embeddings: Embeddings
|
||||
"""Embeddings to use for embedding document contents and queries."""
|
||||
similarity_fn: Callable = cosine_similarity
|
||||
|
@ -14,8 +14,7 @@ from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
"""Wrapper around Elasticsearch using BM25 as a retrieval method.
|
||||
|
||||
"""Retriever for Elasticsearch using BM25 as a retrieval method.
|
||||
|
||||
To connect to an Elasticsearch instance that requires login credentials,
|
||||
including Elastic Cloud, use the Elasticsearch URL format
|
||||
@ -41,12 +40,26 @@ class ElasticSearchBM25Retriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
client: Any
|
||||
"""Elasticsearch client."""
|
||||
index_name: str
|
||||
"""Name of the index to use in Elasticsearch."""
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls, elasticsearch_url: str, index_name: str, k1: float = 2.0, b: float = 0.75
|
||||
) -> ElasticSearchBM25Retriever:
|
||||
"""
|
||||
Create an ElasticSearchBM25Retriever from an Elasticsearch URL and index name.
|
||||
|
||||
Args:
|
||||
elasticsearch_url: URL of the Elasticsearch instance to connect to.
|
||||
index_name: Name of the index to use in Elasticsearch.
|
||||
k1: BM25 parameter k1.
|
||||
b: BM25 parameter b.
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
# Create an Elasticsearch client instance
|
||||
|
@ -51,24 +51,38 @@ class Highlight(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class TextWithHighLights(BaseModel, extra=Extra.allow):
|
||||
"""Text with highlights."""
|
||||
|
||||
Text: str
|
||||
"""The text."""
|
||||
Highlights: Optional[Any]
|
||||
"""The highlights."""
|
||||
|
||||
|
||||
class AdditionalResultAttributeValue(BaseModel, extra=Extra.allow):
|
||||
"""The value of an additional result attribute."""
|
||||
|
||||
TextWithHighlightsValue: TextWithHighLights
|
||||
"""The text with highlights value."""
|
||||
|
||||
|
||||
class AdditionalResultAttribute(BaseModel, extra=Extra.allow):
|
||||
"""An additional result attribute."""
|
||||
|
||||
Key: str
|
||||
"""The key of the attribute."""
|
||||
ValueType: Literal["TEXT_WITH_HIGHLIGHTS_VALUE"]
|
||||
"""The type of the value."""
|
||||
Value: AdditionalResultAttributeValue
|
||||
"""The value of the attribute."""
|
||||
|
||||
def get_value_text(self) -> str:
|
||||
return self.Value.TextWithHighlightsValue.Text
|
||||
|
||||
|
||||
class QueryResultItem(BaseModel, extra=Extra.allow):
|
||||
"""A query result item."""
|
||||
|
||||
DocumentId: str
|
||||
DocumentTitle: TextWithHighLights
|
||||
DocumentURI: Optional[str]
|
||||
@ -111,9 +125,19 @@ class QueryResultItem(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class QueryResult(BaseModel, extra=Extra.allow):
|
||||
"""A query result."""
|
||||
|
||||
ResultItems: List[QueryResultItem]
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
"""Gets the top k documents.
|
||||
|
||||
Args:
|
||||
top_n: The number of documents to return.
|
||||
|
||||
Returns:
|
||||
The top k documents.
|
||||
"""
|
||||
items_len = len(self.ResultItems)
|
||||
count = items_len if items_len < top_n else top_n
|
||||
docs = [self.ResultItems[i].to_doc() for i in range(0, count)]
|
||||
@ -122,24 +146,42 @@ class QueryResult(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class DocumentAttributeValue(BaseModel, extra=Extra.allow):
|
||||
"""The value of a document attribute."""
|
||||
|
||||
DateValue: Optional[str]
|
||||
"""The date value."""
|
||||
LongValue: Optional[int]
|
||||
"""The long value."""
|
||||
StringListValue: Optional[List[str]]
|
||||
"""The string list value."""
|
||||
StringValue: Optional[str]
|
||||
"""The string value."""
|
||||
|
||||
|
||||
class DocumentAttribute(BaseModel, extra=Extra.allow):
|
||||
"""A document attribute."""
|
||||
|
||||
Key: str
|
||||
"""The key of the attribute."""
|
||||
Value: DocumentAttributeValue
|
||||
"""The value of the attribute."""
|
||||
|
||||
|
||||
class RetrieveResultItem(BaseModel, extra=Extra.allow):
|
||||
"""A retrieve result item."""
|
||||
|
||||
Content: Optional[str]
|
||||
"""The content of the item."""
|
||||
DocumentAttributes: Optional[List[DocumentAttribute]] = []
|
||||
"""The document attributes."""
|
||||
DocumentId: Optional[str]
|
||||
"""The document ID."""
|
||||
DocumentTitle: Optional[str]
|
||||
"""The document title."""
|
||||
DocumentURI: Optional[str]
|
||||
"""The document URI."""
|
||||
Id: Optional[str]
|
||||
"""The ID of the item."""
|
||||
|
||||
def get_excerpt(self) -> str:
|
||||
if not self.Content:
|
||||
@ -156,8 +198,12 @@ class RetrieveResultItem(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class RetrieveResult(BaseModel, extra=Extra.allow):
|
||||
"""A retrieve result."""
|
||||
|
||||
QueryId: str
|
||||
"""The ID of the query."""
|
||||
ResultItems: List[RetrieveResultItem]
|
||||
"""The result items."""
|
||||
|
||||
def get_top_k_docs(self, top_n: int) -> List[Document]:
|
||||
items_len = len(self.ResultItems)
|
||||
@ -168,7 +214,7 @@ class RetrieveResult(BaseModel, extra=Extra.allow):
|
||||
|
||||
|
||||
class AmazonKendraRetriever(BaseRetriever):
|
||||
"""Retriever class to query documents from Amazon Kendra Index.
|
||||
"""Retriever for the Amazon Kendra Index.
|
||||
|
||||
Args:
|
||||
index_id: Kendra index id
|
||||
|
@ -36,10 +36,15 @@ class KNNRetriever(BaseRetriever):
|
||||
"""KNN Retriever."""
|
||||
|
||||
embeddings: Embeddings
|
||||
"""Embeddings model to use."""
|
||||
index: Any
|
||||
"""Index of embeddings."""
|
||||
texts: List[str]
|
||||
"""List of texts to index."""
|
||||
k: int = 4
|
||||
"""Number of results to return."""
|
||||
relevancy_threshold: Optional[float] = None
|
||||
"""Threshold for relevancy."""
|
||||
|
||||
class Config:
|
||||
|
||||
|
@ -10,10 +10,13 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class LlamaIndexRetriever(BaseRetriever):
|
||||
"""Question-answering with sources over an LlamaIndex data structure."""
|
||||
"""Retriever for question-answering with sources over
|
||||
a LlamaIndex data structure."""
|
||||
|
||||
index: Any
|
||||
"""LlamaIndex index to query."""
|
||||
query_kwargs: Dict = Field(default_factory=dict)
|
||||
"""Keyword arguments to pass to the query method."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
@ -46,10 +49,13 @@ class LlamaIndexRetriever(BaseRetriever):
|
||||
|
||||
|
||||
class LlamaIndexGraphRetriever(BaseRetriever):
|
||||
"""Question-answering with sources over an LlamaIndex graph data structure."""
|
||||
"""Retriever for question-answering with sources over a LlamaIndex
|
||||
graph data structure."""
|
||||
|
||||
graph: Any
|
||||
"""LlamaIndex graph to query."""
|
||||
query_configs: List[Dict] = Field(default_factory=list)
|
||||
"""List of query configs to pass to the query method."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
|
@ -8,14 +8,10 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class MergerRetriever(BaseRetriever):
|
||||
"""
|
||||
This class merges the results of multiple retrievers.
|
||||
|
||||
Args:
|
||||
retrievers: A list of retrievers to merge.
|
||||
"""
|
||||
"""Retriever that merges the results of multiple retrievers."""
|
||||
|
||||
retrievers: List[BaseRetriever]
|
||||
"""A list of retrievers to merge."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self,
|
||||
|
@ -13,8 +13,9 @@ class MetalRetriever(BaseRetriever):
|
||||
"""Retriever that uses the Metal API."""
|
||||
|
||||
client: Any
|
||||
|
||||
"""The Metal client to use."""
|
||||
params: Optional[dict] = None
|
||||
"""The parameters to pass to the Metal client."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_client(cls, values: dict) -> dict:
|
||||
|
@ -17,10 +17,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LineList(BaseModel):
|
||||
"""List of lines."""
|
||||
|
||||
lines: List[str] = Field(description="Lines of text")
|
||||
"""List of lines."""
|
||||
|
||||
|
||||
class LineListOutputParser(PydanticOutputParser):
|
||||
"""Output parser for a list of lines."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__(pydantic_object=LineList)
|
||||
|
||||
|
@ -99,12 +99,19 @@ def create_index(
|
||||
|
||||
|
||||
class PineconeHybridSearchRetriever(BaseRetriever):
|
||||
"""Pinecone Hybrid Search Retriever."""
|
||||
|
||||
embeddings: Embeddings
|
||||
"""Embeddings model to use."""
|
||||
"""description"""
|
||||
sparse_encoder: Any
|
||||
"""Sparse encoder to use."""
|
||||
index: Any
|
||||
"""Pinecone index to use."""
|
||||
top_k: int = 4
|
||||
"""Number of documents to return."""
|
||||
alpha: float = 0.5
|
||||
"""Alpha value for hybrid search."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""A retriever that uses PubMed API to retrieve documents."""
|
||||
from typing import List
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
@ -10,8 +9,8 @@ from langchain.utilities.pupmed import PubMedAPIWrapper
|
||||
|
||||
|
||||
class PubMedRetriever(BaseRetriever, PubMedAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for PubMedAPIWrapper.
|
||||
"""Retriever for the PubMed API.
|
||||
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all PubMedAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
@ -11,12 +11,20 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class RemoteLangChainRetriever(BaseRetriever):
|
||||
"""Retriever for remote LangChain API."""
|
||||
|
||||
url: str
|
||||
"""URL of the remote LangChain API."""
|
||||
headers: Optional[dict] = None
|
||||
"""Headers to use for the request."""
|
||||
input_key: str = "message"
|
||||
"""Key to use for the input in the request."""
|
||||
response_key: str = "response"
|
||||
"""Key to use for the response in the request."""
|
||||
page_content_key: str = "page_content"
|
||||
"""Key to use for the page content in the response."""
|
||||
metadata_key: str = "metadata"
|
||||
"""Key to use for the metadata in the response."""
|
||||
|
||||
def _get_relevant_documents(
|
||||
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
|
||||
|
@ -52,7 +52,7 @@ def _get_builtin_translator(vectorstore: VectorStore) -> Visitor:
|
||||
|
||||
|
||||
class SelfQueryRetriever(BaseRetriever, BaseModel):
|
||||
"""Retriever that wraps around a vector store and uses an LLM to generate
|
||||
"""Retriever that uses a vector store and an LLM to generate
|
||||
the vector store queries."""
|
||||
|
||||
vectorstore: VectorStore
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Logic for converting internal query language to a valid Chroma query."""
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
|
||||
|
||||
|
||||
class ChromaTranslator(Visitor):
|
||||
"""Logic for converting internal query language elements to valid filters."""
|
||||
"""Translate internal query language elements to valid filters."""
|
||||
|
||||
allowed_operators = [Operator.AND, Operator.OR]
|
||||
"""Subset of allowed logical operators."""
|
||||
|
@ -48,7 +48,7 @@ def FUNCTION_COMPOSER(op_name: str) -> Callable:
|
||||
|
||||
|
||||
class MyScaleTranslator(Visitor):
|
||||
"""Logic for converting internal query language elements to valid filters."""
|
||||
"""Translate internal query language elements to valid filters."""
|
||||
|
||||
allowed_operators = [Operator.AND, Operator.OR, Operator.NOT]
|
||||
"""Subset of allowed logical operators."""
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Logic for converting internal query language to a valid Pinecone query."""
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
|
||||
|
||||
|
||||
class PineconeTranslator(Visitor):
|
||||
"""Logic for converting internal query language elements to valid filters."""
|
||||
"""Translate the internal query language elements to valid filters."""
|
||||
|
||||
allowed_comparators = (
|
||||
Comparator.EQ,
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Logic for converting internal query language to a valid Qdrant query."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
@ -17,7 +16,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class QdrantTranslator(Visitor):
|
||||
"""Logic for converting internal query language elements to valid filters."""
|
||||
"""Translate the internal query language elements to valid filters."""
|
||||
|
||||
allowed_comparators = (
|
||||
Comparator.EQ,
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Logic for converting internal query language to a valid Weaviate query."""
|
||||
from typing import Dict, Tuple, Union
|
||||
|
||||
from langchain.chains.query_constructor.ir import (
|
||||
@ -12,7 +11,7 @@ from langchain.chains.query_constructor.ir import (
|
||||
|
||||
|
||||
class WeaviateTranslator(Visitor):
|
||||
"""Logic for converting internal query language elements to valid filters."""
|
||||
"""Translate the internal query language elements to valid filters."""
|
||||
|
||||
allowed_operators = [Operator.AND, Operator.OR]
|
||||
"""Subset of allowed logical operators."""
|
||||
|
@ -1,7 +1,3 @@
|
||||
"""SMV Retriever.
|
||||
Largely based on
|
||||
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import concurrent.futures
|
||||
@ -20,6 +16,7 @@ from langchain.schema import BaseRetriever, Document
|
||||
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
||||
"""
|
||||
Create an index of embeddings for a list of contexts.
|
||||
|
||||
Args:
|
||||
contexts: List of contexts to embed.
|
||||
embeddings: Embeddings model to use.
|
||||
@ -32,13 +29,22 @@ def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
||||
|
||||
|
||||
class SVMRetriever(BaseRetriever):
|
||||
"""SVM Retriever."""
|
||||
"""SVM Retriever.
|
||||
|
||||
Largely based on
|
||||
https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
|
||||
"""
|
||||
|
||||
embeddings: Embeddings
|
||||
"""Embeddings model to use."""
|
||||
index: Any
|
||||
"""Index of embeddings."""
|
||||
texts: List[str]
|
||||
"""List of texts to index."""
|
||||
k: int = 4
|
||||
"""Number of results to return."""
|
||||
relevancy_threshold: Optional[float] = None
|
||||
"""Threshold for relevancy."""
|
||||
|
||||
class Config:
|
||||
|
||||
|
@ -1,8 +1,3 @@
|
||||
"""TF-IDF Retriever.
|
||||
|
||||
Largely based on
|
||||
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
@ -15,10 +10,20 @@ from langchain.schema import BaseRetriever, Document
|
||||
|
||||
|
||||
class TFIDFRetriever(BaseRetriever):
|
||||
"""TF-IDF Retriever.
|
||||
|
||||
Largely based on
|
||||
https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
|
||||
"""
|
||||
|
||||
vectorizer: Any
|
||||
"""TF-IDF vectorizer."""
|
||||
docs: List[Document]
|
||||
"""Documents."""
|
||||
tfidf_array: Any
|
||||
"""TF-IDF array."""
|
||||
k: int = 4
|
||||
"""Number of documents to return."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Retriever that combines embedding similarity with recency in retrieving values."""
|
||||
|
||||
import datetime
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
@ -20,7 +18,8 @@ def _get_hours_passed(time: datetime.datetime, ref_time: datetime.datetime) -> f
|
||||
|
||||
|
||||
class TimeWeightedVectorStoreRetriever(BaseRetriever):
|
||||
"""Retriever combining embedding similarity with recency."""
|
||||
"""Retriever that combines embedding similarity with
|
||||
recency in retrieving values."""
|
||||
|
||||
vectorstore: VectorStore
|
||||
"""The vectorstore to store documents and determine salience."""
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Wrapper for retrieving documents from Vespa."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@ -16,12 +14,16 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class VespaRetriever(BaseRetriever):
|
||||
"""Retriever that uses the Vespa."""
|
||||
"""Retriever that uses Vespa."""
|
||||
|
||||
app: Vespa
|
||||
"""Vespa application to query."""
|
||||
body: Dict
|
||||
"""Body of the query."""
|
||||
content_field: str
|
||||
"""Name of the content field."""
|
||||
metadata_fields: Sequence[str]
|
||||
"""Names of the metadata fields."""
|
||||
|
||||
def _query(self, body: Dict) -> List[Document]:
|
||||
response = self.app.query(body)
|
||||
@ -97,6 +99,9 @@ class VespaRetriever(BaseRetriever):
|
||||
yql (Optional[str]): Full YQL query to be used. Should not be specified
|
||||
if _filter or sources are specified. Defaults to None.
|
||||
kwargs (Any): Keyword arguments added to query body.
|
||||
|
||||
Returns:
|
||||
VespaRetriever: Instantiated VespaRetriever.
|
||||
"""
|
||||
try:
|
||||
from vespa.application import Vespa
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Wrapper around weaviate vector database."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
@ -16,7 +14,7 @@ from langchain.schema import BaseRetriever
|
||||
|
||||
|
||||
class WeaviateHybridSearchRetriever(BaseRetriever):
|
||||
"""Retriever that uses Weaviate's hybrid search to retrieve documents."""
|
||||
"""Retriever for Weaviate's hybrid search."""
|
||||
|
||||
client: Any
|
||||
"""keyword arguments to pass to the Weaviate client."""
|
||||
|
@ -9,8 +9,8 @@ from langchain.utilities.wikipedia import WikipediaAPIWrapper
|
||||
|
||||
|
||||
class WikipediaRetriever(BaseRetriever, WikipediaAPIWrapper):
|
||||
"""
|
||||
It is effectively a wrapper for WikipediaAPIWrapper.
|
||||
"""Retriever for the Wikipedia API.
|
||||
|
||||
It wraps load() to get_relevant_documents().
|
||||
It uses all WikipediaAPIWrapper arguments without any change.
|
||||
"""
|
||||
|
@ -15,8 +15,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ZepRetriever(BaseRetriever):
|
||||
"""A Retriever implementation for the Zep long-term memory store. Search your
|
||||
user's long-term chat history with Zep.
|
||||
"""Retriever for the Zep long-term memory store.
|
||||
|
||||
Search your user's long-term chat history with Zep.
|
||||
|
||||
Note: You will need to provide the user's `session_id` to use this retriever.
|
||||
|
||||
@ -30,10 +31,11 @@ class ZepRetriever(BaseRetriever):
|
||||
"""
|
||||
|
||||
zep_client: Any
|
||||
|
||||
"""Zep client."""
|
||||
session_id: str
|
||||
|
||||
"""Zep session ID."""
|
||||
top_k: Optional[int]
|
||||
"""Number of documents to return."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def create_client(cls, values: dict) -> dict:
|
||||
|
@ -1,4 +1,3 @@
|
||||
"""Zilliz Retriever"""
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
@ -16,16 +15,22 @@ from langchain.vectorstores.zilliz import Zilliz
|
||||
|
||||
|
||||
class ZillizRetriever(BaseRetriever):
|
||||
"""Retriever that uses the Zilliz API."""
|
||||
"""Retriever for the Zilliz API."""
|
||||
|
||||
embedding_function: Embeddings
|
||||
"""The underlying embedding function from which documents will be retrieved."""
|
||||
collection_name: str = "LangChainCollection"
|
||||
"""The name of the collection in Zilliz."""
|
||||
connection_args: Optional[Dict[str, Any]] = None
|
||||
"""The connection arguments for the Zilliz client."""
|
||||
consistency_level: str = "Session"
|
||||
"""The consistency level for the Zilliz client."""
|
||||
search_params: Optional[dict] = None
|
||||
|
||||
"""The search parameters for the Zilliz client."""
|
||||
store: Zilliz
|
||||
"""The underlying Zilliz store."""
|
||||
retriever: BaseRetriever
|
||||
"""The underlying retriever."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def create_client(cls, values: dict) -> dict:
|
||||
@ -73,8 +78,10 @@ class ZillizRetriever(BaseRetriever):
|
||||
|
||||
|
||||
def ZillizRetreiver(*args: Any, **kwargs: Any) -> ZillizRetriever:
|
||||
"""
|
||||
Deprecated ZillizRetreiver. Please use ZillizRetriever ('i' before 'e') instead.
|
||||
"""Deprecated ZillizRetreiver.
|
||||
|
||||
Please use ZillizRetriever ('i' before 'e') instead.
|
||||
|
||||
Args:
|
||||
*args:
|
||||
**kwargs:
|
||||
|
Loading…
Reference in New Issue
Block a user