mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-01 17:13:22 +00:00
fix: impl missing embeddings method (#10823)
FAISS does not implement embeddings method and use embed_query to embedding texts which is wrong for some embedding models. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
2661dc94f3
commit
77fc2f7644
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
@ -15,6 +16,7 @@ from typing import (
|
|||||||
Optional,
|
Optional,
|
||||||
Sized,
|
Sized,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,6 +28,8 @@ from langchain.schema.embeddings import Embeddings
|
|||||||
from langchain.schema.vectorstore import VectorStore
|
from langchain.schema.vectorstore import VectorStore
|
||||||
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
from langchain.vectorstores.utils import DistanceStrategy, maximal_marginal_relevance
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
def dependable_faiss_import(no_avx2: Optional[bool] = None) -> Any:
|
||||||
"""
|
"""
|
||||||
@ -82,7 +86,7 @@ class FAISS(VectorStore):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
embedding_function: Callable,
|
embedding_function: Union[Callable, Embeddings],
|
||||||
index: Any,
|
index: Any,
|
||||||
docstore: Docstore,
|
docstore: Docstore,
|
||||||
index_to_docstore_id: Dict[int, str],
|
index_to_docstore_id: Dict[int, str],
|
||||||
@ -91,6 +95,11 @@ class FAISS(VectorStore):
|
|||||||
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
distance_strategy: DistanceStrategy = DistanceStrategy.EUCLIDEAN_DISTANCE,
|
||||||
):
|
):
|
||||||
"""Initialize with necessary components."""
|
"""Initialize with necessary components."""
|
||||||
|
if not isinstance(embedding_function, Embeddings):
|
||||||
|
logger.warning(
|
||||||
|
"`embedding_function` is expected to be an Embeddings object, support "
|
||||||
|
"for passing in a function will soon be removed."
|
||||||
|
)
|
||||||
self.embedding_function = embedding_function
|
self.embedding_function = embedding_function
|
||||||
self.index = index
|
self.index = index
|
||||||
self.docstore = docstore
|
self.docstore = docstore
|
||||||
@ -108,6 +117,26 @@ class FAISS(VectorStore):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def embeddings(self) -> Optional[Embeddings]:
|
||||||
|
return (
|
||||||
|
self.embedding_function
|
||||||
|
if isinstance(self.embedding_function, Embeddings)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
|
if isinstance(self.embedding_function, Embeddings):
|
||||||
|
return self.embedding_function.embed_documents(texts)
|
||||||
|
else:
|
||||||
|
return [self.embedding_function(text) for text in texts]
|
||||||
|
|
||||||
|
def _embed_query(self, text: str) -> List[float]:
|
||||||
|
if isinstance(self.embedding_function, Embeddings):
|
||||||
|
return self.embedding_function.embed_query(text)
|
||||||
|
else:
|
||||||
|
return self.embedding_function(text)
|
||||||
|
|
||||||
def __add(
|
def __add(
|
||||||
self,
|
self,
|
||||||
texts: Iterable[str],
|
texts: Iterable[str],
|
||||||
@ -163,7 +192,8 @@ class FAISS(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of ids from adding the texts into the vectorstore.
|
List of ids from adding the texts into the vectorstore.
|
||||||
"""
|
"""
|
||||||
embeddings = [self.embedding_function(text) for text in texts]
|
texts = list(texts)
|
||||||
|
embeddings = self._embed_documents(texts)
|
||||||
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
return self.__add(texts, embeddings, metadatas=metadatas, ids=ids)
|
||||||
|
|
||||||
def add_embeddings(
|
def add_embeddings(
|
||||||
@ -272,7 +302,7 @@ class FAISS(VectorStore):
|
|||||||
List of documents most similar to the query text with
|
List of documents most similar to the query text with
|
||||||
L2 distance in float. Lower score represents more similarity.
|
L2 distance in float. Lower score represents more similarity.
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function(query)
|
embedding = self._embed_query(query)
|
||||||
docs = self.similarity_search_with_score_by_vector(
|
docs = self.similarity_search_with_score_by_vector(
|
||||||
embedding,
|
embedding,
|
||||||
k,
|
k,
|
||||||
@ -465,7 +495,7 @@ class FAISS(VectorStore):
|
|||||||
Returns:
|
Returns:
|
||||||
List of Documents selected by maximal marginal relevance.
|
List of Documents selected by maximal marginal relevance.
|
||||||
"""
|
"""
|
||||||
embedding = self.embedding_function(query)
|
embedding = self._embed_query(query)
|
||||||
docs = self.max_marginal_relevance_search_by_vector(
|
docs = self.max_marginal_relevance_search_by_vector(
|
||||||
embedding,
|
embedding,
|
||||||
k=k,
|
k=k,
|
||||||
@ -561,7 +591,7 @@ class FAISS(VectorStore):
|
|||||||
# Default to L2, currently other metric types not initialized.
|
# Default to L2, currently other metric types not initialized.
|
||||||
index = faiss.IndexFlatL2(len(embeddings[0]))
|
index = faiss.IndexFlatL2(len(embeddings[0]))
|
||||||
vecstore = cls(
|
vecstore = cls(
|
||||||
embedding.embed_query,
|
embedding,
|
||||||
index,
|
index,
|
||||||
InMemoryDocstore(),
|
InMemoryDocstore(),
|
||||||
{},
|
{},
|
||||||
@ -696,9 +726,7 @@ class FAISS(VectorStore):
|
|||||||
# load docstore and index_to_docstore_id
|
# load docstore and index_to_docstore_id
|
||||||
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
|
with open(path / "{index_name}.pkl".format(index_name=index_name), "rb") as f:
|
||||||
docstore, index_to_docstore_id = pickle.load(f)
|
docstore, index_to_docstore_id = pickle.load(f)
|
||||||
return cls(
|
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
|
||||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def serialize_to_bytes(self) -> bytes:
|
def serialize_to_bytes(self) -> bytes:
|
||||||
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
|
"""Serialize FAISS index, docstore, and index_to_docstore_id to bytes."""
|
||||||
@ -713,9 +741,7 @@ class FAISS(VectorStore):
|
|||||||
) -> FAISS:
|
) -> FAISS:
|
||||||
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
|
"""Deserialize FAISS index, docstore, and index_to_docstore_id from bytes."""
|
||||||
index, docstore, index_to_docstore_id = pickle.loads(serialized)
|
index, docstore, index_to_docstore_id = pickle.loads(serialized)
|
||||||
return cls(
|
return cls(embeddings, index, docstore, index_to_docstore_id, **kwargs)
|
||||||
embeddings.embed_query, index, docstore, index_to_docstore_id, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
def _select_relevance_score_fn(self) -> Callable[[float], float]:
|
||||||
"""
|
"""
|
||||||
|
Loading…
Reference in New Issue
Block a user