fix: impl missing embeddings method (#10823)

FAISS does not implement the `embeddings` property and uses `embed_query` to
embed document texts, which is wrong for some embedding models.

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
Sian Cao 2023-10-19 14:51:28 +08:00 committed by GitHub
parent 2661dc94f3
commit 77fc2f7644
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import logging
import operator import operator
import os import os
import pickle import pickle
@ -15,6 +16,7 @@ from typing import (
Optional, Optional,
Sized, Sized,
Tuple, Tuple,
Union,
) )
import numpy as np import numpy as np
@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings
from langchain.schema.vectorstore import VectorStore from langchain.schema.vectorstore import VectorStore
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
logger = logging.getLogger(__name__)
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any: def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
""" """
@ -82,7 +86,7 @@ class FAISS(VectorStore):
def __init__( def __init__(
self, self,
embedding_function: Callable, embedding_function: Union[Callable, Embeddings],
index: Any, index: Any,
docstore: Docstore, docstore: Docstore,
index_to_docstore_id: Dict[int, str], index_to_docstore_id: Dict[int, str],
@ -91,6 +95,11 @@ class FAISS(VectorStore):
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE, distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
): ):
"""Initialize with necessary components.""" """Initialize with necessary components."""
if not isinstance(embedding_function, Embeddings):
logger.warning(
"`embedding_function` is expected to be an Embeddings object, support "
"for passing in a function will soon be removed."
)
self.embedding_function = embedding_function self.embedding_function = embedding_function
self.index = index self.index = index
self.docstore = docstore self.docstore = docstore
@ -108,6 +117,26 @@ class FAISS(VectorStore):
) )
) )
@property
def embeddings(self) -> Optional[Embeddings]:
    """Return the stored embedding object, if there is one.

    Returns:
        ``self.embedding_function`` when it is an ``Embeddings`` instance,
        otherwise ``None`` (e.g. when a plain callable was supplied).
    """
    candidate = self.embedding_function
    if isinstance(candidate, Embeddings):
        return candidate
    return None
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
    """Embed a list of texts with the configured embedding function.

    Args:
        texts: The texts to embed.

    Returns:
        One embedding vector per input text.
    """
    embedder = self.embedding_function
    if not isinstance(embedder, Embeddings):
        # Not an Embeddings object: call it once per text.
        return [embedder(text) for text in texts]
    # Embeddings objects expose a dedicated document-embedding entry point.
    return embedder.embed_documents(texts)
def _embed_query(self, text: str) -> List[float]:
    """Embed a single query string with the configured embedding function.

    Args:
        text: The query text to embed.

    Returns:
        The embedding vector for ``text``.
    """
    embedder = self.embedding_function
    if not isinstance(embedder, Embeddings):
        # Not an Embeddings object: call it directly on the text.
        return embedder(text)
    return embedder.embed_query(text)
def __add( def __add(
self, self,
texts: Iterable[str], texts: Iterable[str],
@ -163,7 +192,8 @@ class FAISS(VectorStore):
Returns: Returns:
List of ids from adding the texts into the vectorstore. List of ids from adding the texts into the vectorstore.
""" """
embeddings = [self.embedding_function(text) for text in texts] texts = list(texts)
embeddings = self._embed_documents(texts)
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids) return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
def add_embeddings( def add_embeddings(
@ -272,7 +302,7 @@ class FAISS(VectorStore):
List of documents most similar to the query text with List of documents most similar to the query text with
L2 distance in float. Lower score represents more similarity. L2 distance in float. Lower score represents more similarity.
""" """
embedding = self.embedding_function(query) embedding = self._embed_query(query)
docs = self.similarity_search_with_score_by_vector( docs = self.similarity_search_with_score_by_vector(
embedding, embedding,
k, k,
@ -465,7 +495,7 @@ class FAISS(VectorStore):
Returns: Returns:
List of Documents selected by maximal marginal relevance. List of Documents selected by maximal marginal relevance.
""" """
embedding = self.embedding_function(query) embedding = self._embed_query(query)
docs = self.max_marginal_relevance_search_by_vector( docs = self.max_marginal_relevance_search_by_vector(
embedding, embedding,
k=k, k=k,
@ -561,7 +591,7 @@ class FAISS(VectorStore):
# Default to L2, currently other metric types not initialized. # Default to L2, currently other metric types not initialized.
index = faiss.IndexFlatL2(len(embeddings[0])) index = faiss.IndexFlatL2(len(embeddings[0]))
vecstore = cls( vecstore = cls(
embedding.embed_query, embedding,
index, index,
InMemoryDocstore(), InMemoryDocstore(),
{}, {},
@ -696,9 +726,7 @@ class FAISS(VectorStore):
# load docstore and index_to_docstore_id # load docstore and index_to_docstore_id
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f: with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
docstore, index_to_docstore_id = pickle.load(f) docstore, index_to_docstore_id = pickle.load(f)
return cls( return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
def serialize_to_bytes(self) -> bytes: def serialize_to_bytes(self) -> bytes:
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes.""" """Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
@ -713,9 +741,7 @@ class FAISS(VectorStore):
) -> FAISS: ) -> FAISS:
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes.""" """Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
index, docstore, index_to_docstore_id = pickle.loads(serialized) index, docstore, index_to_docstore_id = pickle.loads(serialized)
return cls( return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
)
def _select_relevance_score_fn(self) -> Callable[[float], float]: def _select_relevance_score_fn(self) -> Callable[[float], float]:
""" """