mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-23 19:39:58 +00:00
community: Fix FastEmbedEmbeddings (#24462)
## Description This PR: - Fixes the validation error in `FastEmbedEmbeddings`. - Adds support for `batch_size`, `parallel` params. - Removes support for very old FastEmbed versions. - Updates the FastEmbed doc with the new params. Associated Issues: - Resolves #24039 - Resolves #https://github.com/qdrant/fastembed/issues/296
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import importlib
|
||||
import importlib.metadata
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -5,6 +7,8 @@ from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, Extra
|
||||
from langchain_core.utils import pre_init
|
||||
|
||||
MIN_VERSION = "0.2.0"
|
||||
|
||||
|
||||
class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
"""Qdrant FastEmbedding models.
|
||||
@@ -48,12 +52,24 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
The available options are: "default" and "passage"
|
||||
"""
|
||||
|
||||
batch_size: int = 256
|
||||
"""Batch size for encoding. Higher values will use more memory, but be faster.
|
||||
Defaults to 256.
|
||||
"""
|
||||
|
||||
parallel: Optional[int] = None
|
||||
"""If `>1`, parallel encoding is used, recommended for encoding of large datasets.
|
||||
If `0`, use all available cores.
|
||||
If `None`, don't use data-parallel processing, use default onnxruntime threading.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
|
||||
_model: Any # : :meta private:
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
extra = Extra.allow
|
||||
|
||||
@pre_init
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@@ -64,31 +80,25 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
threads = values.get("threads")
|
||||
|
||||
try:
|
||||
# >= v0.2.0
|
||||
from fastembed import TextEmbedding
|
||||
fastembed = importlib.import_module("fastembed")
|
||||
|
||||
values["_model"] = TextEmbedding(
|
||||
model_name=model_name,
|
||||
max_length=max_length,
|
||||
cache_dir=cache_dir,
|
||||
threads=threads,
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"Could not import 'fastembed' Python package. "
|
||||
"Please install it with `pip install fastembed`."
|
||||
)
|
||||
except ImportError as ie:
|
||||
try:
|
||||
# < v0.2.0
|
||||
from fastembed.embedding import FlagEmbedding
|
||||
|
||||
values["_model"] = FlagEmbedding(
|
||||
model_name=model_name,
|
||||
max_length=max_length,
|
||||
cache_dir=cache_dir,
|
||||
threads=threads,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import 'fastembed' Python package. "
|
||||
"Please install it with `pip install fastembed`."
|
||||
) from ie
|
||||
if importlib.metadata.version("fastembed") < MIN_VERSION:
|
||||
raise ImportError(
|
||||
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
|
||||
)
|
||||
|
||||
values["_model"] = fastembed.TextEmbedding(
|
||||
model_name=model_name,
|
||||
max_length=max_length,
|
||||
cache_dir=cache_dir,
|
||||
threads=threads,
|
||||
)
|
||||
return values
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
@@ -102,9 +112,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
embeddings: List[np.ndarray]
|
||||
if self.doc_embed_type == "passage":
|
||||
embeddings = self._model.passage_embed(texts)
|
||||
embeddings = self._model.passage_embed(
|
||||
texts, batch_size=self.batch_size, parallel=self.parallel
|
||||
)
|
||||
else:
|
||||
embeddings = self._model.embed(texts)
|
||||
embeddings = self._model.embed(
|
||||
texts, batch_size=self.batch_size, parallel=self.parallel
|
||||
)
|
||||
return [e.tolist() for e in embeddings]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
@@ -116,5 +130,9 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
query_embeddings: np.ndarray = next(self._model.query_embed(text))
|
||||
query_embeddings: np.ndarray = next(
|
||||
self._model.query_embed(
|
||||
text, batch_size=self.batch_size, parallel=self.parallel
|
||||
)
|
||||
)
|
||||
return query_embeddings.tolist()
|
||||
|
Reference in New Issue
Block a user