zep: VectorStore: Use Native MMR (#12690)

- Refactor ZepVectorStore to use Zep's native MMR search and update the example.
@baskaryan @eyurtsev
This commit is contained in:
Daniel Chalef
2023-11-02 16:45:42 -07:00
committed by GitHub
parent cc3d3920e3
commit 0cbdba6a9b
2 changed files with 185 additions and 149 deletions

View File

@@ -5,12 +5,9 @@ import warnings
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain.docstore.document import Document
from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING:
from zep_python.document import Document as ZepDocument
@@ -112,8 +109,7 @@ class ZepVectorStore(VectorStore):
collection = self._client.document.get_collection(self.collection_name)
except NotFoundError:
logger.info(
f"Collection {self.collection_name} not found. "
"Creating new collection."
f"Collection {self.collection_name} not found. Creating new collection."
)
collection = self._create_collection()
@@ -452,23 +448,6 @@ class ZepVectorStore(VectorStore):
for doc in results
]
def _max_marginal_relevance_selection(
    self,
    query_vector: List[float],
    results: List["ZepDocument"],
    *,
    lambda_mult: float = 0.5,
    k: int = 4,
) -> List[Document]:
    """Pick a subset of Zep search hits via MMR and wrap them as Documents.

    Args:
        query_vector: Embedding of the query; it is wrapped into a
            single-row float32 array before being handed to the MMR helper.
        results: Zep documents whose ``embedding`` attributes are the
            candidates considered by the MMR helper.
        lambda_mult: Forwarded unchanged to ``maximal_marginal_relevance``.
            Defaults to 0.5.
        k: Forwarded unchanged to ``maximal_marginal_relevance``.
            Defaults to 4.

    Returns:
        One ``Document`` per index returned by the MMR helper, in that
        order, built from each hit's ``content`` and ``metadata``.
    """
    query_matrix = np.array([query_vector], dtype=np.float32)
    candidate_embeddings = [hit.embedding for hit in results]

    # The helper returns positions into ``results`` rather than the hits
    # themselves, so they are mapped back to the Zep documents below.
    chosen_indices = maximal_marginal_relevance(
        query_matrix,
        candidate_embeddings,
        lambda_mult=lambda_mult,
        k=k,
    )
    chosen_hits = [results[position] for position in chosen_indices]

    documents = []
    for hit in chosen_hits:
        documents.append(Document(page_content=hit.content, metadata=hit.metadata))
    return documents
def max_marginal_relevance_search(
self,
query: str,
@@ -487,6 +466,8 @@ class ZepVectorStore(VectorStore):
query: Text to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Zep determines this automatically and this parameter is
ignored.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
@@ -504,16 +485,24 @@ class ZepVectorStore(VectorStore):
if not self._collection.is_auto_embedded and self._embedding:
query_vector = self._embedding.embed_query(query)
results = self._collection.search(
embedding=query_vector, limit=k, metadata=metadata, **kwargs
embedding=query_vector,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
else:
results, query_vector = self._collection.search_return_query_vector(
query, limit=k, metadata=metadata, **kwargs
query,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
return self._max_marginal_relevance_selection(
query_vector, results, k=k, lambda_mult=lambda_mult
)
return [Document(page_content=d.content, metadata=d.metadata) for d in results]
async def amax_marginal_relevance_search(
self,
@@ -534,16 +523,24 @@ class ZepVectorStore(VectorStore):
if not self._collection.is_auto_embedded and self._embedding:
query_vector = self._embedding.embed_query(query)
results = await self._collection.asearch(
embedding=query_vector, limit=k, metadata=metadata, **kwargs
embedding=query_vector,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
else:
results, query_vector = await self._collection.asearch_return_query_vector(
query, limit=k, metadata=metadata, **kwargs
query,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
return self._max_marginal_relevance_selection(
query_vector, results, k=k, lambda_mult=lambda_mult
)
return [Document(page_content=d.content, metadata=d.metadata) for d in results]
def max_marginal_relevance_search_by_vector(
self,
@@ -563,6 +560,8 @@ class ZepVectorStore(VectorStore):
embedding: Embedding to look up documents similar to.
k: Number of Documents to return. Defaults to 4.
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
Zep determines this automatically and this parameter is
ignored.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
@@ -577,12 +576,15 @@ class ZepVectorStore(VectorStore):
)
results = self._collection.search(
embedding=embedding, limit=k, metadata=metadata, **kwargs
embedding=embedding,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
return self._max_marginal_relevance_selection(
embedding, results, k=k, lambda_mult=lambda_mult
)
return [Document(page_content=d.content, metadata=d.metadata) for d in results]
async def amax_marginal_relevance_search_by_vector(
self,
@@ -600,12 +602,15 @@ class ZepVectorStore(VectorStore):
)
results = await self._collection.asearch(
embedding=embedding, limit=k, metadata=metadata, **kwargs
embedding=embedding,
limit=k,
metadata=metadata,
search_type="mmr",
mmr_lambda=lambda_mult,
**kwargs,
)
return self._max_marginal_relevance_selection(
embedding, results, k=k, lambda_mult=lambda_mult
)
return [Document(page_content=d.content, metadata=d.metadata) for d in results]
@classmethod
def from_texts(