mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-26 16:43:35 +00:00
Refactor vector storage to correctly handle relevancy scores (#6570)
Description: This pull request aims to support generating the correct generic relevancy scores for different vector stores by refactoring the relevance score functions and their selection in the base class and subclasses of VectorStore. This is especially relevant with VectorStores that require a distance metric upon initialization. Note many of the current implenetations of `_similarity_search_with_relevance_scores` are not technically correct, as they just return `self.similarity_search_with_score(query, k, **kwargs)` without applying the relevant score function Also includes changes associated with: https://github.com/hwchase17/langchain/pull/6564 and https://github.com/hwchase17/langchain/pull/6494 See more indepth discussion in thread in #6494 Issue: https://github.com/hwchase17/langchain/issues/6526 https://github.com/hwchase17/langchain/issues/6481 https://github.com/hwchase17/langchain/issues/6346 Dependencies: None The changes include: - Properly handling score thresholding in FAISS `similarity_search_with_score_by_vector` for the corresponding distance metric. - Refactoring the `_similarity_search_with_relevance_scores` method in the base class and removing it from the subclasses for incorrectly implemented subclasses. - Adding a `_select_relevance_score_fn` method in the base class and implementing it in the subclasses to select the appropriate relevance score function based on the distance strategy. - Updating the `__init__` methods of the subclasses to set the `relevance_score_fn` attribute. - Removing the `_default_relevance_score_fn` function from the FAISS class and using the base class's `_euclidean_relevance_score_fn` instead. - Adding the `DistanceStrategy` enum to the `utils.py` file and updating the imports in the vector store classes. - Updating the tests to import the `DistanceStrategy` enum from the `utils.py` file. --------- Co-authored-by: Hanit <37485638+hanit-com@users.noreply.github.com>
This commit is contained in:
parent
bd0c6381f5
commit
5171c3bcca
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Type
|
||||
|
||||
from sqlalchemy import REAL, Column, String, Table, create_engine, insert, text
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSON, TEXT
|
||||
@ -79,6 +79,9 @@ class AnalyticDB(VectorStore):
|
||||
self.engine = create_engine(self.connection_string, **_engine_args)
|
||||
self.create_collection()
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._euclidean_relevance_score_fn
|
||||
|
||||
def create_table_if_not_exists(self) -> None:
|
||||
# Define the dynamic table
|
||||
Table(
|
||||
@ -242,28 +245,6 @@ class AnalyticDB(VectorStore):
|
||||
)
|
||||
return docs
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
|
||||
def similarity_search_with_score_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
|
@ -3,11 +3,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import partial
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Dict,
|
||||
@ -137,6 +139,81 @@ class VectorStore(ABC):
|
||||
) -> List[Document]:
|
||||
"""Return docs most similar to query."""
|
||||
|
||||
@staticmethod
|
||||
def _euclidean_relevance_score_fn(distance: float) -> float:
|
||||
"""Return a similarity score on a scale [0, 1]."""
|
||||
# The 'correct' relevance function
|
||||
# may differ depending on a few things, including:
|
||||
# - the distance / similarity metric used by the VectorStore
|
||||
# - the scale of your embeddings (OpenAI's are unit normed. Many
|
||||
# others are not!)
|
||||
# - embedding dimensionality
|
||||
# - etc.
|
||||
# This function converts the euclidean norm of normalized embeddings
|
||||
# (0 is most similar, sqrt(2) most dissimilar)
|
||||
# to a similarity function (0 to 1)
|
||||
return 1.0 - distance / math.sqrt(2)
|
||||
|
||||
@staticmethod
|
||||
def _cosine_relevance_score_fn(distance: float) -> float:
|
||||
"""Normalize the distance to a score on a scale [0, 1]."""
|
||||
|
||||
return 1.0 - distance
|
||||
|
||||
@staticmethod
|
||||
def _max_inner_product_relevance_score_fn(distance: float) -> float:
|
||||
"""Normalize the distance to a score on a scale [0, 1]."""
|
||||
if distance > 0:
|
||||
return 1.0 - distance
|
||||
|
||||
return -1.0 * distance
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
|
||||
Vectorstores should define their own selection based method of relevance.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def similarity_search_with_score(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Run similarity search with distance."""
|
||||
raise NotImplementedError
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""
|
||||
Default similarity search with relevance scores. Modify if necessary
|
||||
in subclass.
|
||||
Return docs and relevance scores in the range [0, 1].
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
|
||||
Args:
|
||||
query: input text
|
||||
k: Number of Documents to return. Defaults to 4.
|
||||
**kwargs: kwargs to be passed to similarity search. Should include:
|
||||
score_threshold: Optional, a floating point value between 0 to 1 to
|
||||
filter the resulting set of retrieved docs
|
||||
|
||||
Returns:
|
||||
List of Tuples of (doc, similarity_score)
|
||||
"""
|
||||
relevance_score_fn = self._select_relevance_score_fn()
|
||||
docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
|
||||
return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||
|
||||
def similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
@ -183,18 +260,6 @@ class VectorStore(ABC):
|
||||
)
|
||||
return docs_and_similarities
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores, normalized on a scale from 0 to 1.
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def asimilarity_search_with_relevance_scores(
|
||||
self, query: str, k: int = 4, **kwargs: Any
|
||||
) -> List[Tuple[Document, float]]:
|
||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import uuid
|
||||
from typing import Any, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||
from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -77,6 +77,9 @@ class Cassandra(VectorStore):
|
||||
primary_key_type="TEXT",
|
||||
)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._cosine_relevance_score_fn
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
"""
|
||||
Just an alias for `clear`
|
||||
@ -268,21 +271,6 @@ class Cassandra(VectorStore):
|
||||
k,
|
||||
)
|
||||
|
||||
# Even though this is a `_`-method,
|
||||
# it is apparently used by VectorSearch parent class
|
||||
# in an exposed method (`similarity_search_with_relevance_scores`).
|
||||
# So we implement it (hmm).
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
return self.similarity_search_with_score(
|
||||
query,
|
||||
k,
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
embedding: List[float],
|
||||
|
@ -3,7 +3,17 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -64,6 +74,7 @@ class Chroma(VectorStore):
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
client: Optional[chromadb.Client] = None,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
) -> None:
|
||||
"""Initialize with Chroma client."""
|
||||
try:
|
||||
@ -100,6 +111,7 @@ class Chroma(VectorStore):
|
||||
else None,
|
||||
metadata=collection_metadata,
|
||||
)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
|
||||
@xor_args(("query_texts", "query_embeddings"))
|
||||
def __query_collection(
|
||||
@ -250,13 +262,37 @@ class Chroma(VectorStore):
|
||||
|
||||
return _results_to_docs_and_scores(results)
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
if self.override_relevance_score_fn:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
distance = "l2"
|
||||
distance_key = "hnsw:space"
|
||||
metadata = self._collection.metadata
|
||||
|
||||
if metadata and distance_key in metadata:
|
||||
distance = metadata[distance_key]
|
||||
|
||||
if distance == "cosine":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif distance == "l2":
|
||||
return self._euclidean_relevance_score_fn
|
||||
elif distance == "ip":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"No supported normalization function"
|
||||
f" for distance metric of type: {distance}."
|
||||
"Consider providing relevance_score_fn to Chroma constructor."
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
@ -428,6 +464,7 @@ class Chroma(VectorStore):
|
||||
persist_directory: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
client: Optional[chromadb.Client] = None,
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from a raw documents.
|
||||
@ -443,6 +480,8 @@ class Chroma(VectorStore):
|
||||
metadatas (Optional[List[dict]]): List of metadatas. Defaults to None.
|
||||
ids (Optional[List[str]]): List of document IDs. Defaults to None.
|
||||
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
|
||||
collection_metadata (Optional[Dict]): Collection configurations.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Chroma: Chroma vectorstore.
|
||||
@ -453,6 +492,8 @@ class Chroma(VectorStore):
|
||||
persist_directory=persist_directory,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
chroma_collection.add_texts(texts=texts, metadatas=metadatas, ids=ids)
|
||||
return chroma_collection
|
||||
@ -467,6 +508,7 @@ class Chroma(VectorStore):
|
||||
persist_directory: Optional[str] = None,
|
||||
client_settings: Optional[chromadb.config.Settings] = None,
|
||||
client: Optional[chromadb.Client] = None, # Add this line
|
||||
collection_metadata: Optional[Dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Chroma:
|
||||
"""Create a Chroma vectorstore from a list of documents.
|
||||
@ -481,6 +523,9 @@ class Chroma(VectorStore):
|
||||
documents (List[Document]): List of documents to add to the vectorstore.
|
||||
embedding (Optional[Embeddings]): Embedding function. Defaults to None.
|
||||
client_settings (Optional[chromadb.config.Settings]): Chroma client settings
|
||||
collection_metadata (Optional[Dict]): Collection configurations.
|
||||
Defaults to None.
|
||||
|
||||
Returns:
|
||||
Chroma: Chroma vectorstore.
|
||||
"""
|
||||
@ -495,6 +540,8 @@ class Chroma(VectorStore):
|
||||
persist_directory=persist_directory,
|
||||
client_settings=client_settings,
|
||||
client=client,
|
||||
collection_metadata=collection_metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None:
|
||||
|
@ -1,10 +1,11 @@
|
||||
"""Wrapper around FAISS vector database."""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import pickle
|
||||
import uuid
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
@ -15,7 +16,7 @@ from langchain.docstore.document import Document
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
|
||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
@ -45,20 +46,6 @@ def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||
return faiss
|
||||
|
||||
|
||||
def _default_relevance_score_fn(score: float) -> float:
|
||||
"""Return a similarity score on a scale [0, 1]."""
|
||||
# The 'correct' relevance function
|
||||
# may differ depending on a few things, including:
|
||||
# - the distance / similarity metric used by the VectorStore
|
||||
# - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
# - embedding dimensionality
|
||||
# - etc.
|
||||
# This function converts the euclidean norm of normalized embeddings
|
||||
# (0 is most similar, sqrt(2) most dissimilar)
|
||||
# to a similarity function (0 to 1)
|
||||
return 1.0 - score / math.sqrt(2)
|
||||
|
||||
|
||||
class FAISS(VectorStore):
|
||||
"""Wrapper around FAISS vector database.
|
||||
|
||||
@ -78,16 +65,27 @@ class FAISS(VectorStore):
|
||||
index: Any,
|
||||
docstore: Docstore,
|
||||
index_to_docstore_id: Dict[int, str],
|
||||
relevance_score_fn: Callable[[float], float] = _default_relevance_score_fn,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
normalize_L2: bool = False,
|
||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
self.embedding_function = embedding_function
|
||||
self.index = index
|
||||
self.docstore = docstore
|
||||
self.index_to_docstore_id = index_to_docstore_id
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
self.distance_strategy = distance_strategy
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self._normalize_L2 = normalize_L2
|
||||
if (
|
||||
self.distance_strategy != DistanceStrategy.EUCLIDEAN_DISTANCE
|
||||
and self._normalize_L2
|
||||
):
|
||||
warnings.warn(
|
||||
"Normalizing L2 is not applicable for metric type: {strategy}".format(
|
||||
strategy=self.distance_strategy
|
||||
)
|
||||
)
|
||||
|
||||
def __add(
|
||||
self,
|
||||
@ -227,10 +225,16 @@ class FAISS(VectorStore):
|
||||
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
operator.ge
|
||||
if self.distance_strategy
|
||||
in (DistanceStrategy.MAX_INNER_PRODUCT, DistanceStrategy.JACCARD)
|
||||
else operator.le
|
||||
)
|
||||
docs = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs
|
||||
if similarity >= score_threshold
|
||||
if cmp(similarity, score_threshold)
|
||||
]
|
||||
return docs[:k]
|
||||
|
||||
@ -498,9 +502,16 @@ class FAISS(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> FAISS:
|
||||
faiss = dependable_faiss_import()
|
||||
distance_strategy = kwargs.get(
|
||||
"distance_strategy", DistanceStrategy.EUCLIDEAN_DISTANCE
|
||||
)
|
||||
if distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
index = faiss.IndexFlatIP(len(embeddings[0]))
|
||||
else:
|
||||
# Default to L2, currently other metric types not initialized.
|
||||
index = faiss.IndexFlatL2(len(embeddings[0]))
|
||||
vector = np.array(embeddings, dtype=np.float32)
|
||||
if normalize_L2:
|
||||
if normalize_L2 and distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
||||
faiss.normalize_L2(vector)
|
||||
index.add(vector)
|
||||
documents = []
|
||||
@ -646,6 +657,31 @@ class FAISS(VectorStore):
|
||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
||||
)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
if self.override_relevance_score_fn is not None:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
# Default strategy is to rely on distance strategy provided in
|
||||
# vectorstore constructor
|
||||
if self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
||||
# Default behavior is to use euclidean distance relevancy
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown distance strategy, must be cosine, max_inner_product,"
|
||||
" or euclidean"
|
||||
)
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
@ -655,6 +691,12 @@ class FAISS(VectorStore):
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and their similarity scores on a scale from 0 to 1."""
|
||||
relevance_score_fn = self._select_relevance_score_fn()
|
||||
if relevance_score_fn is None:
|
||||
raise ValueError(
|
||||
"normalize_score_fn must be provided to"
|
||||
" FAISS constructor to normalize scores"
|
||||
)
|
||||
docs_and_scores = self.similarity_search_with_score(
|
||||
query,
|
||||
k=k,
|
||||
@ -662,4 +704,4 @@ class FAISS(VectorStore):
|
||||
fetch_k=fetch_k,
|
||||
**kwargs,
|
||||
)
|
||||
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||
return [(doc, relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
import enum
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.dialects.postgresql import JSON, UUID
|
||||
@ -121,6 +121,7 @@ class PGVector(VectorStore):
|
||||
distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
|
||||
pre_delete_collection: bool = False,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
) -> None:
|
||||
self.connection_string = connection_string
|
||||
self.embedding_function = embedding_function
|
||||
@ -129,6 +130,7 @@ class PGVector(VectorStore):
|
||||
self._distance_strategy = distance_strategy
|
||||
self.pre_delete_collection = pre_delete_collection
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.override_relevance_score_fn = relevance_score_fn
|
||||
self.__post_init__()
|
||||
|
||||
def __post_init__(
|
||||
@ -201,6 +203,11 @@ class PGVector(VectorStore):
|
||||
pre_delete_collection: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> PGVector:
|
||||
if ids is None:
|
||||
ids = [str(uuid.uuid1()) for _ in texts]
|
||||
|
||||
if not metadatas:
|
||||
metadatas = [{} for _ in texts]
|
||||
connection_string = cls.get_connection_string(kwargs)
|
||||
|
||||
store = cls(
|
||||
@ -209,6 +216,7 @@ class PGVector(VectorStore):
|
||||
embedding_function=embedding,
|
||||
distance_strategy=distance_strategy,
|
||||
pre_delete_collection=pre_delete_collection,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
store.add_embeddings(
|
||||
@ -590,3 +598,30 @@ class PGVector(VectorStore):
|
||||
) -> str:
|
||||
"""Return connection string from database parameters."""
|
||||
return f"postgresql+{driver}://{user}:{password}@{host}:{port}/{database}"
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
if self.override_relevance_score_fn is not None:
|
||||
return self.override_relevance_score_fn
|
||||
|
||||
# Default strategy is to rely on distance strategy provided
|
||||
# in vectorstore constructor
|
||||
if self.distance_strategy == DistanceStrategy.COSINE:
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN:
|
||||
return self._euclidean_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"No supported normalization function"
|
||||
f" for distance_strategy of {self.distance_strategy}."
|
||||
"Consider providing relevance_score_fn to PGVector constructor."
|
||||
)
|
||||
|
@ -10,7 +10,7 @@ import numpy as np
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -40,6 +40,7 @@ class Pinecone(VectorStore):
|
||||
index: Any,
|
||||
embedding_function: Callable,
|
||||
text_key: str,
|
||||
distance_strategy: Optional[DistanceStrategy] = DistanceStrategy.COSINE,
|
||||
):
|
||||
"""Initialize with Pinecone client."""
|
||||
try:
|
||||
@ -57,6 +58,7 @@ class Pinecone(VectorStore):
|
||||
self._index = index
|
||||
self._embedding_function = embedding_function
|
||||
self._text_key = text_key
|
||||
self.distance_strategy = distance_strategy
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
@ -147,14 +149,27 @@ class Pinecone(VectorStore):
|
||||
)
|
||||
return [doc for doc, _ in docs_and_scores]
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
kwargs.pop("score_threshold", None)
|
||||
return self.similarity_search_with_score(query, k, **kwargs)
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
|
||||
if self.distance_strategy == DistanceStrategy.COSINE:
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_strategy == DistanceStrategy.EUCLIDEAN_DISTANCE:
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown distance strategy, must be cosine, max_inner_product "
|
||||
"(dot product), or euclidean"
|
||||
)
|
||||
|
||||
def max_marginal_relevance_search_by_vector(
|
||||
self,
|
||||
|
@ -61,6 +61,7 @@ class Qdrant(VectorStore):
|
||||
embeddings: Optional[Embeddings] = None,
|
||||
content_payload_key: str = CONTENT_KEY,
|
||||
metadata_payload_key: str = METADATA_KEY,
|
||||
distance_strategy: str = "COSINE",
|
||||
vector_name: Optional[str] = VECTOR_NAME,
|
||||
embedding_function: Optional[Callable] = None, # deprecated
|
||||
):
|
||||
@ -112,6 +113,8 @@ class Qdrant(VectorStore):
|
||||
self._embeddings_function = embeddings
|
||||
self.embeddings = None
|
||||
|
||||
self.distance_strategy = distance_strategy.upper()
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -419,6 +422,28 @@ class Qdrant(VectorStore):
|
||||
for result in results
|
||||
]
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
"""
|
||||
The 'correct' relevance function
|
||||
may differ depending on a few things, including:
|
||||
- the distance / similarity metric used by the VectorStore
|
||||
- the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
|
||||
- embedding dimensionality
|
||||
- etc.
|
||||
"""
|
||||
|
||||
if self.distance_strategy == "COSINE":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_strategy == "DOT":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_strategy == "EUCLID":
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unknown distance strategy, must be cosine, "
|
||||
"max_inner_product, or euclidean"
|
||||
)
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
@ -723,6 +748,7 @@ class Qdrant(VectorStore):
|
||||
embeddings=embedding,
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
distance_strategy=distance_func,
|
||||
vector_name=vector_name,
|
||||
)
|
||||
|
||||
|
@ -121,9 +121,8 @@ class Redis(VectorStore):
|
||||
content_key: str = "content",
|
||||
metadata_key: str = "metadata",
|
||||
vector_key: str = "content_vector",
|
||||
relevance_score_fn: Optional[
|
||||
Callable[[float], float]
|
||||
] = _default_relevance_score,
|
||||
relevance_score_fn: Optional[Callable[[float], float]] = None,
|
||||
distance_metric: REDIS_DISTANCE_METRICS = "COSINE",
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize with necessary components."""
|
||||
@ -149,11 +148,23 @@ class Redis(VectorStore):
|
||||
self.content_key = content_key
|
||||
self.metadata_key = metadata_key
|
||||
self.vector_key = vector_key
|
||||
self.distance_metric = distance_metric
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
|
||||
def _create_index(
|
||||
self, dim: int = 1536, distance_metric: REDIS_DISTANCE_METRICS = "COSINE"
|
||||
) -> None:
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
if self.relevance_score_fn:
|
||||
return self.relevance_score_fn
|
||||
|
||||
if self.distance_metric == "COSINE":
|
||||
return self._cosine_relevance_score_fn
|
||||
elif self.distance_metric == "IP":
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
elif self.distance_metric == "L2":
|
||||
return self._euclidean_relevance_score_fn
|
||||
else:
|
||||
return _default_relevance_score
|
||||
|
||||
def _create_index(self, dim: int = 1536) -> None:
|
||||
try:
|
||||
from redis.commands.search.field import TextField, VectorField
|
||||
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
|
||||
@ -175,7 +186,7 @@ class Redis(VectorStore):
|
||||
{
|
||||
"TYPE": "FLOAT32",
|
||||
"DIM": dim,
|
||||
"DISTANCE_METRIC": distance_metric,
|
||||
"DISTANCE_METRIC": self.distance_metric,
|
||||
},
|
||||
),
|
||||
)
|
||||
@ -347,24 +358,6 @@ class Redis(VectorStore):
|
||||
|
||||
return docs
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores, normalized on a scale from 0 to 1.
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
"""
|
||||
if self.relevance_score_fn is None:
|
||||
raise ValueError(
|
||||
"relevance_score_fn must be provided to"
|
||||
" Redis constructor to normalize scores"
|
||||
)
|
||||
docs_and_scores = self.similarity_search_with_score(query, k=k)
|
||||
return [(doc, self.relevance_score_fn(score)) for doc, score in docs_and_scores]
|
||||
|
||||
@classmethod
|
||||
def from_texts_return_keys(
|
||||
cls,
|
||||
@ -413,6 +406,7 @@ class Redis(VectorStore):
|
||||
content_key=content_key,
|
||||
metadata_key=metadata_key,
|
||||
vector_key=vector_key,
|
||||
distance_metric=distance_metric,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -420,7 +414,7 @@ class Redis(VectorStore):
|
||||
embeddings = embedding.embed_documents(texts)
|
||||
|
||||
# Create the search index
|
||||
instance._create_index(dim=len(embeddings[0]), distance_metric=distance_metric)
|
||||
instance._create_index(dim=len(embeddings[0]))
|
||||
|
||||
# Add data to Redis
|
||||
keys = instance.add_texts(texts, metadatas, embeddings)
|
||||
|
@ -2,9 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import json
|
||||
from typing import Any, ClassVar, Collection, Iterable, List, Optional, Tuple, Type
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ClassVar,
|
||||
Collection,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
)
|
||||
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
@ -15,14 +24,7 @@ from langchain.callbacks.manager import (
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.base import VectorStore, VectorStoreRetriever
|
||||
|
||||
|
||||
class DistanceStrategy(str, enum.Enum):
|
||||
"""Enumerator of the Distance strategies for SingleStoreDB."""
|
||||
|
||||
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
|
||||
DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.DOT_PRODUCT
|
||||
|
||||
@ -211,6 +213,9 @@ class SingleStoreDB(VectorStore):
|
||||
)
|
||||
self._create_table()
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return self._max_inner_product_relevance_score_fn
|
||||
|
||||
def _create_table(self: SingleStoreDB) -> None:
|
||||
"""Create table if it doesn't exist."""
|
||||
conn = self.connection_pool.connect()
|
||||
|
@ -1,5 +1,6 @@
|
||||
"""Utility functions for working with vectors and vectorstores."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
@ -7,6 +8,14 @@ import numpy as np
|
||||
from langchain.math_utils import cosine_similarity
|
||||
|
||||
|
||||
class DistanceStrategy(str, Enum):
|
||||
EUCLIDEAN_DISTANCE = "EUCLIDEAN_DISTANCE"
|
||||
MAX_INNER_PRODUCT = "MAX_INNER_PRODUCT"
|
||||
DOT_PRODUCT = "DOT_PRODUCT"
|
||||
JACCARD = "JACCARD"
|
||||
COSINE = "COSINE"
|
||||
|
||||
|
||||
def maximal_marginal_relevance(
|
||||
query_embedding: np.ndarray,
|
||||
embedding_list: list,
|
||||
|
@ -113,11 +113,18 @@ class Weaviate(VectorStore):
|
||||
self._embedding = embedding
|
||||
self._text_key = text_key
|
||||
self._query_attrs = [self._text_key]
|
||||
self._relevance_score_fn = relevance_score_fn
|
||||
self.relevance_score_fn = relevance_score_fn
|
||||
self._by_text = by_text
|
||||
if attributes is not None:
|
||||
self._query_attrs.extend(attributes)
|
||||
|
||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||
return (
|
||||
self.relevance_score_fn
|
||||
if self.relevance_score_fn
|
||||
else _default_score_normalizer
|
||||
)
|
||||
|
||||
def add_texts(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
@ -361,26 +368,6 @@ class Weaviate(VectorStore):
|
||||
docs_and_scores.append((Document(page_content=text, metadata=res), score))
|
||||
return docs_and_scores
|
||||
|
||||
def _similarity_search_with_relevance_scores(
|
||||
self,
|
||||
query: str,
|
||||
k: int = 4,
|
||||
**kwargs: Any,
|
||||
) -> List[Tuple[Document, float]]:
|
||||
"""Return docs and relevance scores, normalized on a scale from 0 to 1.
|
||||
|
||||
0 is dissimilar, 1 is most similar.
|
||||
"""
|
||||
if self._relevance_score_fn is None:
|
||||
raise ValueError(
|
||||
"relevance_score_fn must be provided to"
|
||||
" Weaviate constructor to normalize scores"
|
||||
)
|
||||
docs_and_scores = self.similarity_search_with_score(query, k=k, **kwargs)
|
||||
return [
|
||||
(doc, self._relevance_score_fn(score)) for doc, score in docs_and_scores
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def from_texts(
|
||||
cls: Type[Weaviate],
|
||||
|
@ -228,3 +228,42 @@ def test_chroma_update_document() -> None:
|
||||
]
|
||||
assert new_embedding == embedding.embed_documents([updated_content])[0]
|
||||
assert new_embedding != old_embedding
|
||||
|
||||
|
||||
def test_chroma_with_relevance_score() -> None:
|
||||
"""Test to make sure the relevance score is scaled to 0-1."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = Chroma.from_texts(
|
||||
collection_name="test_collection",
|
||||
texts=texts,
|
||||
embedding=FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
collection_metadata={"hnsw:space": "l2"},
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
assert output == [
|
||||
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
|
||||
(Document(page_content="bar", metadata={"page": "1"}), 0.8),
|
||||
(Document(page_content="baz", metadata={"page": "2"}), 0.5),
|
||||
]
|
||||
|
||||
|
||||
def test_chroma_with_relevance_score_custom_normalization_fn() -> None:
|
||||
"""Test searching with relevance score and custom normalization function."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = Chroma.from_texts(
|
||||
collection_name="test_collection",
|
||||
texts=texts,
|
||||
embedding=FakeEmbeddings(),
|
||||
metadatas=metadatas,
|
||||
relevance_score_fn=lambda d: d * 0,
|
||||
collection_metadata={"hnsw:space": "l2"},
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
assert output == [
|
||||
(Document(page_content="foo", metadata={"page": "0"}), -0.0),
|
||||
(Document(page_content="bar", metadata={"page": "1"}), -0.0),
|
||||
(Document(page_content="baz", metadata={"page": "2"}), -0.0),
|
||||
]
|
||||
|
@ -195,3 +195,14 @@ def test_faiss_invalid_normalize_fn() -> None:
|
||||
)
|
||||
with pytest.warns(Warning, match="scores must be between"):
|
||||
docsearch.similarity_search_with_relevance_scores("foo", k=1)
|
||||
|
||||
|
||||
def test_missing_normalize_score_fn() -> None:
|
||||
"""Test doesn't perform similarity search without a valid distance strategy."""
|
||||
with pytest.raises(ValueError):
|
||||
texts = ["foo", "bar", "baz"]
|
||||
faiss_instance = FAISS.from_texts(
|
||||
texts, FakeEmbeddings(), distance_strategy="fake"
|
||||
)
|
||||
|
||||
faiss_instance.similarity_search_with_relevance_scores("foo", k=2)
|
||||
|
@ -184,3 +184,70 @@ def test_pgvector_with_filter_in_set() -> None:
|
||||
(Document(page_content="foo", metadata={"page": "0"}), 0.0),
|
||||
(Document(page_content="baz", metadata={"page": "2"}), 0.0013003906671379406),
|
||||
]
|
||||
|
||||
|
||||
def test_pgvector_relevance_score() -> None:
|
||||
"""Test to make sure the relevance score is scaled to 0-1."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
metadatas=metadatas,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
assert output == [
|
||||
(Document(page_content="foo", metadata={"page": "0"}), 1.0),
|
||||
(Document(page_content="bar", metadata={"page": "1"}), 0.9996744261675065),
|
||||
(Document(page_content="baz", metadata={"page": "2"}), 0.9986996093328621),
|
||||
]
|
||||
|
||||
|
||||
def test_pgvector_retriever_search_threshold() -> None:
|
||||
"""Test using retriever for searching with threshold."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
metadatas=metadatas,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
)
|
||||
|
||||
retriever = docsearch.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": 3, "score_threshold": 0.999},
|
||||
)
|
||||
output = retriever.get_relevant_documents("summer")
|
||||
assert output == [
|
||||
Document(page_content="foo", metadata={"page": "0"}),
|
||||
Document(page_content="bar", metadata={"page": "1"}),
|
||||
]
|
||||
|
||||
|
||||
def test_pgvector_retriever_search_threshold_custom_normalization_fn() -> None:
|
||||
"""Test searching with threshold and custom normalization function"""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": str(i)} for i in range(len(texts))]
|
||||
docsearch = PGVector.from_texts(
|
||||
texts=texts,
|
||||
collection_name="test_collection",
|
||||
embedding=FakeEmbeddingsWithAdaDimension(),
|
||||
metadatas=metadatas,
|
||||
connection_string=CONNECTION_STRING,
|
||||
pre_delete_collection=True,
|
||||
relevance_score_fn=lambda d: d * 0,
|
||||
)
|
||||
|
||||
retriever = docsearch.as_retriever(
|
||||
search_type="similarity_score_threshold",
|
||||
search_kwargs={"k": 3, "score_threshold": 0.5},
|
||||
)
|
||||
output = retriever.get_relevant_documents("foo")
|
||||
assert output == []
|
||||
|
@ -2,6 +2,7 @@ import importlib
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
import numpy as np
|
||||
from typing import List
|
||||
|
||||
import pinecone
|
||||
@ -154,3 +155,21 @@ class TestPinecone:
|
||||
time.sleep(20)
|
||||
index_stats = self.index.describe_index_stats()
|
||||
assert index_stats["total_vector_count"] == len(texts) * 2
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_relevance_score_bound(self, embedding_openai: OpenAIEmbeddings) -> None:
|
||||
"""Ensures all relevance scores are between 0 and 1."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
metadatas = [{"page": i} for i in range(len(texts))]
|
||||
docsearch = Pinecone.from_texts(
|
||||
texts,
|
||||
embedding_openai,
|
||||
index_name=index_name,
|
||||
metadatas=metadatas,
|
||||
)
|
||||
# wait for the index to be ready
|
||||
time.sleep(20)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
assert all(
|
||||
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
|
||||
)
|
||||
|
@ -1,6 +1,7 @@
|
||||
"""Test Qdrant functionality."""
|
||||
import tempfile
|
||||
from typing import Callable, Optional
|
||||
import numpy as np
|
||||
|
||||
import pytest
|
||||
from qdrant_client.http import models as rest
|
||||
@ -513,3 +514,26 @@ def test_qdrant_add_texts_stores_embeddings_as_named_vectors(vector_name: str) -
|
||||
vector_name in point.vector # type: ignore[operator]
|
||||
for point in client.scroll(collection_name, with_vectors=True)[0]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", [1, 64])
|
||||
@pytest.mark.parametrize("content_payload_key", [Qdrant.CONTENT_KEY, "foo"])
|
||||
@pytest.mark.parametrize("metadata_payload_key", [Qdrant.METADATA_KEY, "bar"])
|
||||
def test_qdrant_similarity_search_with_relevance_scores(
|
||||
batch_size: int, content_payload_key: str, metadata_payload_key: str
|
||||
) -> None:
|
||||
"""Test end to end construction and search."""
|
||||
texts = ["foo", "bar", "baz"]
|
||||
docsearch = Qdrant.from_texts(
|
||||
texts,
|
||||
ConsistentFakeEmbeddings(),
|
||||
location=":memory:",
|
||||
content_payload_key=content_payload_key,
|
||||
metadata_payload_key=metadata_payload_key,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
output = docsearch.similarity_search_with_relevance_scores("foo", k=3)
|
||||
|
||||
assert all(
|
||||
(1 >= score or np.isclose(score, 1)) and score >= 0 for _, score in output
|
||||
)
|
||||
|
@ -5,7 +5,8 @@ import numpy as np
|
||||
import pytest
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.vectorstores.singlestoredb import DistanceStrategy, SingleStoreDB
|
||||
from langchain.vectorstores.singlestoredb import SingleStoreDB
|
||||
from langchain.vectorstores.utils import DistanceStrategy
|
||||
from tests.integration_tests.vectorstores.fake_embeddings import FakeEmbeddings
|
||||
|
||||
TEST_SINGLESTOREDB_URL = "root:pass@localhost:3306/db"
|
||||
|
Loading…
Reference in New Issue
Block a user