Add MMR methods to chroma (#2148)

Hi, I added MMR (maximal marginal relevance) search to Chroma, similar to the
existing FAISS and Milvus implementations. Please let me know what you think.
This commit is contained in:
Arttii 2023-03-31 05:51:16 +02:00 committed by GitHub
parent fc009f61c8
commit 4e9ee566ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,9 +5,12 @@ import logging
import uuid import uuid
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
if TYPE_CHECKING: if TYPE_CHECKING:
import chromadb import chromadb
@ -182,6 +185,69 @@ class Chroma(VectorStore):
return _results_to_docs_and_scores(results) return _results_to_docs_and_scores(results)
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
    """Pick documents for an embedding using maximal marginal relevance (MMR).

    MMR balances similarity to the query against diversity among the
    documents that end up being selected.

    Args:
        embedding: Query embedding to find similar documents for.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of candidate Documents fetched from the collection
            and handed to the MMR algorithm. Defaults to 20.
        filter (Optional[Dict[str, str]]): Metadata filter forwarded to the
            collection query as ``where``. Defaults to None.

    Returns:
        The Documents chosen by MMR, listed in the order the collection
        returned the candidates.
    """
    # "embeddings" has to be requested so the candidate vectors are
    # available to the MMR computation below.
    include_fields = ["metadatas", "documents", "distances", "embeddings"]
    query_result = self._collection.query(
        query_embeddings=embedding,
        n_results=fetch_k,
        where=filter,
        include=include_fields,
    )

    query_vector = np.array(embedding, dtype=np.float32)
    # Index 0 selects the result set belonging to the (single) query.
    candidate_vectors = query_result["embeddings"][0]
    chosen_positions = maximal_marginal_relevance(
        query_vector, candidate_vectors, k=k
    )

    # Keep only the candidates whose position was picked by MMR.
    picked = []
    for position, doc in enumerate(_results_to_docs(query_result)):
        if position in chosen_positions:
            picked.append(doc)
    return picked
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents. The query text is embedded with the store's
    embedding function and the work is delegated to
    ``max_marginal_relevance_search_by_vector``.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If no embedding function is set on this store.
    """
    if self._embedding_function is None:
        # Single literal: the previous implicit concatenation of two adjacent
        # literals dropped the space and produced "...function oncreation.".
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )

    embedding = self._embedding_function.embed_query(query)
    return self.max_marginal_relevance_search_by_vector(
        embedding, k, fetch_k, filter
    )
def delete_collection(self) -> None: def delete_collection(self) -> None:
"""Delete the collection.""" """Delete the collection."""
self._client.delete_collection(self._collection.name) self._client.delete_collection(self._collection.name)