mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 12:18:24 +00:00
Add MMR methods to chroma (#2148)
Hi, I added MMR search methods to Chroma, similar to those in FAISS and Milvus. Please let me know what you think.
This commit is contained in:
parent
fc009f61c8
commit
4e9ee566ef
@ -5,9 +5,12 @@ import logging
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import chromadb
|
import chromadb
|
||||||
@ -182,6 +185,69 @@ class Chroma(VectorStore):
|
|||||||
|
|
||||||
return _results_to_docs_and_scores(results)
|
return _results_to_docs_and_scores(results)
|
||||||
|
|
||||||
|
def max_marginal_relevance_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    fetch_k: int = 20,
    filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        embedding: Embedding to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance, in the
        order in which the MMR helper returned their indices.
    """
    results = self._collection.query(
        query_embeddings=embedding,
        n_results=fetch_k,
        where=filter,
        # The stored embeddings are requested explicitly because the MMR
        # step below compares candidates against each other.
        include=["metadatas", "documents", "distances", "embeddings"],
    )
    mmr_selected = maximal_marginal_relevance(
        np.array(embedding, dtype=np.float32), results["embeddings"][0], k=k
    )

    candidates = _results_to_docs(results)

    # Walk the selected indices rather than the candidates so the MMR
    # ranking is kept. Indices that do not address a candidate, or that
    # repeat, are skipped so each candidate appears at most once.
    selected_results: List[Document] = []
    seen = set()
    for idx in mmr_selected:
        if idx in seen or not 0 <= idx < len(candidates):
            continue
        seen.add(idx)
        selected_results.append(candidates[idx])
    return selected_results
|
||||||
|
|
||||||
|
def max_marginal_relevance_search(
    self,
    query: str,
    k: int = 4,
    fetch_k: int = 20,
    filter: Optional[Dict[str, str]] = None,
) -> List[Document]:
    """Return docs selected using the maximal marginal relevance.

    Maximal marginal relevance optimizes for similarity to query AND diversity
    among selected documents.

    Args:
        query: Text to look up documents similar to.
        k: Number of Documents to return. Defaults to 4.
        fetch_k: Number of Documents to fetch to pass to MMR algorithm.
            Defaults to 20.
        filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.

    Returns:
        List of Documents selected by maximal marginal relevance.

    Raises:
        ValueError: If no embedding function is set on this instance, so the
            query text cannot be turned into a vector.
    """
    if self._embedding_function is None:
        # Kept as a single literal: the former implicit concatenation of two
        # adjacent strings dropped the space and rendered "oncreation.".
        raise ValueError(
            "For MMR search, you must specify an embedding function on creation."
        )

    # Embed the query text, then delegate to the vector-based MMR search.
    embedding = self._embedding_function.embed_query(query)
    docs = self.max_marginal_relevance_search_by_vector(
        embedding, k, fetch_k, filter
    )
    return docs
|
||||||
|
|
||||||
def delete_collection(self) -> None:
    """Remove this store's collection via the underlying client.

    The collection is identified by the ``name`` attribute of
    ``self._collection`` and the deletion is delegated to
    ``self._client.delete_collection``.
    """
    collection_name = self._collection.name
    self._client.delete_collection(collection_name)
|
||||||
|
Loading…
Reference in New Issue
Block a user