mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-09 23:12:38 +00:00
Implemented MMR search for PGVector (#10396)
Description: Implemented MMR search for PGVector. Issue: #7466 Dependencies: None Tag maintainer: Twitter handle: @JohnMai95
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from functools import partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -17,6 +19,7 @@ from typing import (
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import sqlalchemy
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
@@ -26,6 +29,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain.vectorstores._pgvector_data_models import CollectionStore
|
||||
@@ -54,6 +58,11 @@ class BaseModel(Base):
|
||||
uuid = sqlalchemy.Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
||||
|
||||
def _results_to_docs(docs_and_scores: Any) -> List[Document]:
|
||||
"""Return docs from docs and scores."""
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
|
||||
class PGVector(VectorStore):
|
||||
"""`Postgres`/`PGVector` vector store.
|
||||
|
||||
@@ -339,7 +348,7 @@ class PGVector(VectorStore):
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List of Documents most similar to the query and score for each
|
||||
List of Documents most similar to the query and score for each.
|
||||
"""
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
docs = self.similarity_search_with_score_by_vector(
|
||||
@@ -367,6 +376,31 @@ class PGVector(VectorStore):
|
||||
k: int = 4,
|
||||
filter: Optional[dict] = None,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
results = self.__query_collection(embedding=embedding, k=k, filter=filter)
|
||||
|
||||
return self._results_to_docs_and_scores(results)
|
||||
|
||||
def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and scores from results."""
|
||||
docs = [
|
||||
(
|
||||
Document(
|
||||
page_content=result.EmbeddingStore.document,
|
||||
metadata=result.EmbeddingStore.cmetadata,
|
||||
),
|
||||
result.distance if self.embedding_function is not None else None,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
return docs
|
||||
|
||||
def __query_collection(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
) -> List[Any]:
|
||||
"""Query the collection."""
|
||||
with Session(self._conn) as session:
|
||||
collection = self.get_collection(session)
|
||||
if not collection:
|
||||
@@ -410,18 +444,7 @@ class PGVector(VectorStore):
|
||||
.limit(k)
|
||||
.all()
|
||||
)
|
||||
|
||||
docs = [
|
||||
(
|
||||
Document(
|
||||
page_content=result.EmbeddingStore.document,
|
||||
metadata=result.EmbeddingStore.cmetadata,
|
||||
),
|
||||
result.distance if self.embedding_function is not None else None,
|
||||
)
|
||||
for result in results
|
||||
]
|
||||
return docs
|
||||
return results
|
||||
|
||||
def similarity_search_by_vector(
|
||||
self,
|
||||
@@ -443,7 +466,7 @@ class PGVector(VectorStore):
|
||||
docs_and_scores = self.similarity_search_with_score_by_vector(
|
||||
embedding=embedding, k=k, filter=filter
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
return _results_to_docs(docs_and_scores)
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
@@ -640,3 +663,190 @@ class PGVector(VectorStore):
|
||||
f" for distance_strategy of {self._distance_strategy}."
|
||||
"Consider providing relevance_score_fn to PGVector constructor."
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs selected using the maximal marginal relevance with score
|
||||
to embedding vector.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding: Embedding to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
|
||||
Defaults to 20.
|
||||
lambda_mult (float): Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
|
||||
relevance to the query and score for each.
|
||||
"""
|
||||
results = self.__query_collection(embedding=embedding, k=fetch_k, filter=filter)
|
||||
|
||||
embedding_list = [result.EmbeddingStore.embedding for result in results]
|
||||
|
||||
mmr_selected = maximal_marginal_relevance(
|
||||
np.array(embedding, dtype=np.float32),
|
||||
embedding_list,
|
||||
k=k,
|
||||
lambda_mult=lambda_mult,
|
||||
)
|
||||
|
||||
candidates = self._results_to_docs_and_scores(results)
|
||||
|
||||
return [r for i, r in enumerate(candidates) if i in mmr_selected]
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query (str): Text to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
|
||||
Defaults to 20.
|
||||
lambda_mult (float): Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Document]: List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
return self.max_marginal_relevance_search_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_with_score(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs selected using the maximal marginal relevance with score.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query (str): Text to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
|
||||
Defaults to 20.
|
||||
lambda_mult (float): Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Tuple[Document, float]]: List of Documents selected by maximal marginal
|
||||
relevance to the query and score for each.
|
||||
"""
|
||||
embedding = self.embedding_function.embed_query(query)
|
||||
docs = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding=embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return docs
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance
|
||||
to embedding vector.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
embedding (str): Text to look up documents similar to.
|
||||
k (int): Number of Documents to return. Defaults to 4.
|
||||
fetch_k (int): Number of Documents to fetch to pass to MMR algorithm.
|
||||
Defaults to 20.
|
||||
lambda_mult (float): Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Document]: List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
docs_and_scores = self.max_marginal_relevance_search_with_score_by_vector(
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return _results_to_docs(docs_and_scores)
|
||||
|
||||
async def amax_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
filter: Optional[Dict[str, str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance."""
|
||||
|
||||
# This is a temporary workaround to make the similarity search
|
||||
# asynchronous. The proper solution is to make the similarity search
|
||||
# asynchronous in the vector store implementations.
|
||||
func = partial(
|
||||
self.max_marginal_relevance_search_by_vector,
|
||||
embedding,
|
||||
k=k,
|
||||
fetch_k=fetch_k,
|
||||
lambda_mult=lambda_mult,
|
||||
filter=filter,
|
||||
**kwargs,
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
||||
|
@@ -279,3 +279,31 @@ def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
|
||||
)
|
||||
output = retriever.get_relevant_documents("foo")
|
||||
assert output == []
|
||||
|
||||
|
||||
def test_pgvector_max_marginal_relevance_search() -> None:
|
||||
"""Test max marginal relevance search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search("foo", k=1, fetch_k=3)
|
||||
assert output == [Document(page_content="foo")]
|
||||
|
||||
|
||||
def test_pgvector_max_marginal_relevance_search_with_score() -> None:
|
||||
"""Test max marginal relevance search with relevance scores."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
output = docsearch.max_marginal_relevance_search_with_score("foo", k=1, fetch_k=3)
|
||||
assert output == [(Document(page_content="foo"), 0.0)]
|
||||
|
Reference in New Issue
Block a user