mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community[patch]: mmr search for Rockset vectorstore integration (#16908)
- **Description:** Adding support for mmr search in the Rockset vectorstore integration. - **Issue:** N/A - **Dependencies:** N/A - **Twitter handle:** `@_morgan_adams_` --------- Co-authored-by: Rockset API Bot <admin@rockset.io> Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
This commit is contained in:
parent
f51e6a35ba
commit
074ad5095f
@ -5,11 +5,14 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.runnables import run_in_executor
|
||||
from langchain_core.vectorstores import VectorStore
|
||||
|
||||
from langchain_community.vectorstores.utils import maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -254,7 +257,12 @@ class Rockset(VectorStore):
|
||||
"""Accepts a query_embedding (vector), and returns documents with
|
||||
similar embeddings along with their relevance scores."""
|
||||
|
||||
q_str = self._build_query_sql(embedding, distance_func, k, where_str)
|
||||
exclude_embeddings = True
|
||||
if "exclude_embeddings" in kwargs:
|
||||
exclude_embeddings = kwargs["exclude_embeddings"]
|
||||
q_str = self._build_query_sql(
|
||||
embedding, distance_func, k, where_str, exclude_embeddings
|
||||
)
|
||||
try:
|
||||
query_response = self._client.Queries.query(sql={"query": q_str})
|
||||
except Exception as e:
|
||||
@ -290,6 +298,60 @@ class Rockset(VectorStore):
|
||||
)
|
||||
return finalResult
|
||||
|
||||
def max_marginal_relevance_search(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
fetch_k: int = 20,
|
||||
lambda_mult: float = 0.5,
|
||||
*,
|
||||
where_str: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> List[Document]:
|
||||
"""Return docs selected using the maximal marginal relevance.
|
||||
|
||||
Maximal marginal relevance optimizes for similarity to query AND diversity
|
||||
among selected documents.
|
||||
|
||||
Args:
|
||||
query: Text to look up documents similar to.
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
fetch_k: Number of Documents to fetch to pass to MMR algorithm.
|
||||
distance_func (DistanceFunction): how to compute distance between two
|
||||
vectors in Rockset.
|
||||
lambda_mult: Number between 0 and 1 that determines the degree
|
||||
of diversity among the results with 0 corresponding
|
||||
to maximum diversity and 1 to minimum diversity.
|
||||
Defaults to 0.5.
|
||||
where_str: where clause for the sql query
|
||||
Returns:
|
||||
List of Documents selected by maximal marginal relevance.
|
||||
"""
|
||||
query_embedding = self._embeddings.embed_query(query)
|
||||
initial_docs = self.similarity_search_by_vector(
|
||||
query_embedding,
|
||||
k=fetch_k,
|
||||
where_str=where_str,
|
||||
exclude_embeddings=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
embeddings = [doc.metadata[self._embedding_key] for doc in initial_docs]
|
||||
|
||||
selected_indices = maximal_marginal_relevance(
|
||||
np.array(query_embedding),
|
||||
embeddings,
|
||||
lambda_mult=lambda_mult,
|
||||
k=k,
|
||||
)
|
||||
|
||||
# remove embeddings key before returning for cleanup to be consistent with
|
||||
# other search functions
|
||||
for i in selected_indices:
|
||||
del initial_docs[i].metadata[self._embedding_key]
|
||||
|
||||
return [initial_docs[i] for i in selected_indices]
|
||||
|
||||
# Helper functions
|
||||
|
||||
def _build_query_sql(
|
||||
@ -298,6 +360,7 @@ class Rockset(VectorStore):
|
||||
distance_func: DistanceFunction,
|
||||
k: int = 4,
|
||||
where_str: Optional[str] = None,
|
||||
exclude_embeddings: bool = True,
|
||||
) -> str:
|
||||
"""Builds Rockset SQL query to query similar vectors to query_vector"""
|
||||
|
||||
@ -305,8 +368,11 @@ class Rockset(VectorStore):
|
||||
distance_str = f"""{distance_func.value}({self._embedding_key}, \
|
||||
[{q_embedding_str}]) as dist"""
|
||||
where_str = f"WHERE {where_str}\n" if where_str else ""
|
||||
select_embedding = (
|
||||
f" EXCEPT({self._embedding_key})," if exclude_embeddings else ","
|
||||
)
|
||||
return f"""\
|
||||
SELECT * EXCEPT({self._embedding_key}), {distance_str}
|
||||
SELECT *{select_embedding} {distance_str}
|
||||
FROM {self._workspace}.{self._collection_name}
|
||||
{where_str}\
|
||||
ORDER BY dist {distance_func.order_by()}
|
||||
|
@ -96,16 +96,18 @@ class TestRockset:
|
||||
client, embeddings, COLLECTION_NAME, TEXT_KEY, EMBEDDING_KEY, WORKSPACE
|
||||
)
|
||||
|
||||
def test_rockset_insert_and_search(self) -> None:
|
||||
"""Test end to end vector search in Rockset"""
|
||||
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"metadata_index": i} for i in range(len(texts))]
|
||||
ids = self.rockset_vectorstore.add_texts(
|
||||
ids = cls.rockset_vectorstore.add_texts(
|
||||
texts=texts,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
|
||||
assert len(ids) == len(texts)
|
||||
|
||||
def test_rockset_search(self) -> None:
|
||||
"""Test end-to-end vector search in Rockset"""
|
||||
|
||||
# Test that `foo` is closest to `foo`
|
||||
output = self.rockset_vectorstore.similarity_search(
|
||||
query="foo", distance_func=Rockset.DistanceFunction.COSINE_SIM, k=1
|
||||
@ -121,6 +123,26 @@ class TestRockset:
|
||||
)
|
||||
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
||||
|
||||
def test_rockset_mmr_search(self) -> None:
|
||||
"""Test end-to-end mmr search in Rockset"""
|
||||
output = self.rockset_vectorstore.max_marginal_relevance_search(
|
||||
query="foo",
|
||||
distance_func=Rockset.DistanceFunction.COSINE_SIM,
|
||||
fetch_k=1,
|
||||
k=1,
|
||||
)
|
||||
assert output == [Document(page_content="foo", metadata={"metadata_index": 0})]
|
||||
|
||||
# Find closest vector to `foo` which is not `foo`
|
||||
output = self.rockset_vectorstore.max_marginal_relevance_search(
|
||||
query="foo",
|
||||
distance_func=Rockset.DistanceFunction.COSINE_SIM,
|
||||
fetch_k=3,
|
||||
k=1,
|
||||
where_str="metadata_index != 0",
|
||||
)
|
||||
assert output == [Document(page_content="bar", metadata={"metadata_index": 1})]
|
||||
|
||||
def test_add_documents_and_delete(self) -> None:
|
||||
""" "add_documents" and "delete" are requirements to support use
|
||||
with RecordManager"""
|
||||
@ -184,5 +206,21 @@ FROM {WORKSPACE}.{COLLECTION_NAME}
|
||||
WHERE age >= 10
|
||||
ORDER BY dist DESC
|
||||
LIMIT 4
|
||||
"""
|
||||
assert q_str == expected
|
||||
|
||||
def test_build_query_sql_with_select_embeddings(self) -> None:
|
||||
vector = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
|
||||
q_str = self.rockset_vectorstore._build_query_sql(
|
||||
vector, Rockset.DistanceFunction.COSINE_SIM, 4, "age >= 10", False
|
||||
)
|
||||
vector_str = ",".join(map(str, vector))
|
||||
expected = f"""\
|
||||
SELECT *, \
|
||||
COSINE_SIM({EMBEDDING_KEY}, [{vector_str}]) as dist
|
||||
FROM {WORKSPACE}.{COLLECTION_NAME}
|
||||
WHERE age >= 10
|
||||
ORDER BY dist DESC
|
||||
LIMIT 4
|
||||
"""
|
||||
assert q_str == expected
|
||||
|
Loading…
Reference in New Issue
Block a user