Compare commits

...

1 Commit

Author SHA1 Message Date
Bagatur
ecb86cb58d rfc 2023-10-12 14:38:12 -07:00

View File

@@ -13,12 +13,15 @@ from typing import (
ClassVar,
Collection,
Dict,
Iterable,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
TypeVar,
Union,
)
from langchain.pydantic_v1 import Field, root_validator
@@ -35,18 +38,22 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
VST = TypeVar("VST", bound="VectorStore")
ID_TYPE = TypeVar("ID_TYPE")
SEARCH_HIT = Tuple[Document, dict]
SEARCH_RESULTS = Tuple[List[SEARCH_HIT], dict]
class VectorStore(ABC):
class VectorStore(Generic[ID_TYPE], ABC):
"""Interface for vector store."""
@abstractmethod
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
texts: Sequence[str],
*,
metadatas: Optional[Sequence[dict]] = None,
**kwargs: Any,
) -> List[str]:
) -> List[ID_TYPE]:
"""Run more texts through the embeddings and add to the vectorstore.
Args:
@@ -59,14 +66,14 @@ class VectorStore(ABC):
"""
@property
def embeddings(self) -> Optional[Embeddings]:
def embeddings(self) -> Embeddings:
"""Access the query embedding object if available."""
logger.debug(
f"{Embeddings.__name__} is not implemented for {self.__class__.__name__}"
)
return None
raise NotImplementedError
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
def delete(self, ids: Optional[Sequence[ID_TYPE]] = None, **kwargs: Any) -> bool:
"""Delete by vector ID or other criteria.
Args:
@@ -74,16 +81,15 @@ class VectorStore(ABC):
**kwargs: Other keyword arguments that subclasses might use.
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
bool: True if deletion is successful, False otherwise.
"""
raise NotImplementedError("delete method must be implemented by subclass.")
async def adelete(
self, ids: Optional[List[str]] = None, **kwargs: Any
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
self, ids: Optional[Sequence[ID_TYPE]] = None, **kwargs: Any
) -> bool:
"""Async delete by vector ID or other criteria.
Args:
ids: List of ids to delete.
@@ -91,23 +97,26 @@ class VectorStore(ABC):
Returns:
Optional[bool]: True if deletion is successful,
False otherwise, None if not implemented.
bool: True if deletion is successful, False otherwise.
"""
raise NotImplementedError("delete method must be implemented by subclass.")
async def aadd_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
texts: Sequence[str],
*,
metadatas: Optional[Sequence[dict]] = None,
**kwargs: Any,
) -> List[str]:
) -> List[ID_TYPE]:
"""Run more texts through the embeddings and add to the vectorstore."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(self.add_texts, **kwargs), texts, metadatas
)
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
def add_documents(
self, documents: Sequence[Document], **kwargs: Any
) -> List[ID_TYPE]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
@@ -119,11 +128,11 @@ class VectorStore(ABC):
# TODO: Handle the case where the user doesn't provide ids on the Collection
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return self.add_texts(texts, metadatas, **kwargs)
return self.add_texts(texts, metadatas=metadatas, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
self, documents: Sequence[Document], **kwargs: Any
) -> List[ID_TYPE]:
"""Run more documents through the embeddings and add to the vectorstore.
Args:
@@ -134,37 +143,55 @@ class VectorStore(ABC):
"""
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
return await self.aadd_texts(texts, metadatas, **kwargs)
return await self.aadd_texts(texts, metadata=metadatas, **kwargs)
def search(self, query: str, search_type: str, **kwargs: Any) -> List[Document]:
def search(
self, query: Union[str, List[float]], search_type: str, **kwargs: Any
) -> SEARCH_RESULTS:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return self.similarity_search(query, **kwargs)
if isinstance(query, str):
docs = self.similarity_search(query, **kwargs)
else:
docs = self.similarity_search_by_vector(query, **kwargs)
elif search_type == "mmr":
return self.max_marginal_relevance_search(query, **kwargs)
if isinstance(query, str):
docs = self.max_marginal_relevance_search(query, **kwargs)
else:
docs = self.max_marginal_relevance_search_by_vector(query, **kwargs)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
return [(d, {}) for d in docs], {}
async def asearch(
self, query: str, search_type: str, **kwargs: Any
) -> List[Document]:
) -> SEARCH_RESULTS:
"""Return docs most similar to query using specified search type."""
if search_type == "similarity":
return await self.asimilarity_search(query, **kwargs)
if isinstance(query, str):
docs = await self.asimilarity_search(query, **kwargs)
else:
docs = await self.asimilarity_search_by_vector(query, **kwargs)
elif search_type == "mmr":
return await self.amax_marginal_relevance_search(query, **kwargs)
if isinstance(query, str):
docs = await self.amax_marginal_relevance_search(query, **kwargs)
else:
docs = await self.amax_marginal_relevance_search_by_vector(
query, **kwargs
)
else:
raise ValueError(
f"search_type of {search_type} not allowed. Expected "
"search_type to be 'similarity' or 'mmr'."
)
return [(d, {}) for d in docs], {}
@abstractmethod
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
self, query: str, *, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
@@ -219,6 +246,7 @@ class VectorStore(ABC):
def _similarity_search_with_relevance_scores(
self,
query: str,
*,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -246,6 +274,7 @@ class VectorStore(ABC):
def similarity_search_with_relevance_scores(
self,
query: str,
*,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
@@ -291,7 +320,7 @@ class VectorStore(ABC):
return docs_and_similarities
async def asimilarity_search_with_relevance_scores(
self, query: str, k: int = 4, **kwargs: Any
self, query: str, *, k: int = 4, **kwargs: Any
) -> List[Tuple[Document, float]]:
"""Return docs most similar to query."""
@@ -304,7 +333,7 @@ class VectorStore(ABC):
return await asyncio.get_event_loop().run_in_executor(None, func)
async def asimilarity_search(
self, query: str, k: int = 4, **kwargs: Any
self, query: str, *, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to query."""
@@ -315,7 +344,7 @@ class VectorStore(ABC):
return await asyncio.get_event_loop().run_in_executor(None, func)
def similarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
self, embedding: List[float], *, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to embedding vector.
@@ -329,7 +358,7 @@ class VectorStore(ABC):
raise NotImplementedError
async def asimilarity_search_by_vector(
self, embedding: List[float], k: int = 4, **kwargs: Any
self, embedding: List[float], *, k: int = 4, **kwargs: Any
) -> List[Document]:
"""Return docs most similar to embedding vector."""
@@ -342,6 +371,7 @@ class VectorStore(ABC):
def max_marginal_relevance_search(
self,
query: str,
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
@@ -368,6 +398,7 @@ class VectorStore(ABC):
async def amax_marginal_relevance_search(
self,
query: str,
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
@@ -391,6 +422,7 @@ class VectorStore(ABC):
def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
@@ -417,6 +449,7 @@ class VectorStore(ABC):
async def amax_marginal_relevance_search_by_vector(
self,
embedding: List[float],
*,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
@@ -428,34 +461,32 @@ class VectorStore(ABC):
@classmethod
def from_documents(
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
documents: Sequence[Document],
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
return cls.from_texts(texts, embedding, metadatas=metadatas, **kwargs)
return cls.from_texts(texts, metadatas=metadatas, **kwargs)
@classmethod
async def afrom_documents(
cls: Type[VST],
documents: List[Document],
embedding: Embeddings,
documents: Sequence[Document],
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from documents and embeddings."""
texts = [d.page_content for d in documents]
metadatas = [d.metadata for d in documents]
return await cls.afrom_texts(texts, embedding, metadatas=metadatas, **kwargs)
return await cls.afrom_texts(texts, metadatas=metadatas, **kwargs)
@classmethod
@abstractmethod
def from_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
texts: Sequence[str],
*,
metadatas: Optional[Sequence[dict]] = None,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
@@ -463,14 +494,14 @@ class VectorStore(ABC):
@classmethod
async def afrom_texts(
cls: Type[VST],
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
texts: Sequence[str],
*,
metadatas: Optional[Sequence[dict]] = None,
**kwargs: Any,
) -> VST:
"""Return VectorStore initialized from texts and embeddings."""
return await asyncio.get_running_loop().run_in_executor(
None, partial(cls.from_texts, **kwargs), texts, embedding, metadatas
None, partial(cls.from_texts, **kwargs), texts, metadatas
)
def _get_retriever_tags(self) -> List[str]:
@@ -480,7 +511,13 @@ class VectorStore(ABC):
tags.append(self.embeddings.__class__.__name__)
return tags
def as_retriever(self, **kwargs: Any) -> VectorStoreRetriever:
def as_retriever(
self,
search_type: str = "similarity",
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**search_kwargs: Any,
) -> VectorStoreRetriever:
"""Return VectorStoreRetriever initialized from this VectorStore.
Args:
@@ -534,10 +571,11 @@ class VectorStore(ABC):
search_kwargs={'filter': {'paper_title':'GPT-4 Technical Report'}}
)
"""
tags = kwargs.pop("tags", None) or []
tags = tags or []
tags.extend(self._get_retriever_tags())
return VectorStoreRetriever(vectorstore=self, **kwargs, tags=tags)
return VectorStoreRetriever(
vectorstore=self, search_kwargs=search_kwargs, tags=tags, metadata=metadata
)
class VectorStoreRetriever(BaseRetriever):
@@ -620,12 +658,12 @@ class VectorStoreRetriever(BaseRetriever):
raise ValueError(f"search_type of {self.search_type} not allowed.")
return docs
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
def add_documents(self, documents: Sequence[Document], **kwargs: Any) -> List[str]:
"""Add documents to vectorstore."""
return self.vectorstore.add_documents(documents, **kwargs)
async def aadd_documents(
self, documents: List[Document], **kwargs: Any
self, documents: Sequence[Document], **kwargs: Any
) -> List[str]:
"""Add documents to vectorstore."""
return await self.vectorstore.aadd_documents(documents, **kwargs)