community: Fix FastEmbedEmbeddings (#24462)

## Description

This PR:
- Fixes the validation error in `FastEmbedEmbeddings`.
- Adds support for `batch_size`, `parallel` params.
- Removes support for very old FastEmbed versions.
- Updates the FastEmbed doc with the new params.

Associated Issues:
- Resolves #24039
- Resolves #https://github.com/qdrant/fastembed/issues/296
This commit is contained in:
Anush
2024-07-30 22:12:46 +05:30
committed by GitHub
parent 73ec24fc56
commit 51b15448cc
4 changed files with 69 additions and 35 deletions

View File

@@ -1,3 +1,5 @@
import importlib
import importlib.metadata
from typing import Any, Dict, List, Literal, Optional
import numpy as np
@@ -5,6 +7,8 @@ from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra
from langchain_core.utils import pre_init
MIN_VERSION = "0.2.0"
class FastEmbedEmbeddings(BaseModel, Embeddings):
"""Qdrant FastEmbedding models.
@@ -48,12 +52,24 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
The available options are: "default" and "passage"
"""
batch_size: int = 256
"""Batch size for encoding. Higher values will use more memory, but be faster.
Defaults to 256.
"""
parallel: Optional[int] = None
"""If `>1`, parallel encoding is used, recommended for encoding of large datasets.
If `0`, use all available cores.
If `None`, don't use data-parallel processing, use default onnxruntime threading.
Defaults to `None`.
"""
_model: Any # : :meta private:
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
extra = Extra.allow
@pre_init
def validate_environment(cls, values: Dict) -> Dict:
@@ -64,31 +80,25 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
threads = values.get("threads")
try:
# >= v0.2.0
from fastembed import TextEmbedding
fastembed = importlib.import_module("fastembed")
values["_model"] = TextEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
except ModuleNotFoundError:
raise ImportError(
"Could not import 'fastembed' Python package. "
"Please install it with `pip install fastembed`."
)
except ImportError as ie:
try:
# < v0.2.0
from fastembed.embedding import FlagEmbedding
values["_model"] = FlagEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
)
except ImportError:
raise ImportError(
"Could not import 'fastembed' Python package. "
"Please install it with `pip install fastembed`."
) from ie
if importlib.metadata.version("fastembed") < MIN_VERSION:
raise ImportError(
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
)
values["_model"] = fastembed.TextEmbedding(
model_name=model_name,
max_length=max_length,
cache_dir=cache_dir,
threads=threads,
)
return values
def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -102,9 +112,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
"""
embeddings: List[np.ndarray]
if self.doc_embed_type == "passage":
embeddings = self._model.passage_embed(texts)
embeddings = self._model.passage_embed(
texts, batch_size=self.batch_size, parallel=self.parallel
)
else:
embeddings = self._model.embed(texts)
embeddings = self._model.embed(
texts, batch_size=self.batch_size, parallel=self.parallel
)
return [e.tolist() for e in embeddings]
def embed_query(self, text: str) -> List[float]:
@@ -116,5 +130,9 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
query_embeddings: np.ndarray = next(self._model.query_embed(text))
query_embeddings: np.ndarray = next(
self._model.query_embed(
text, batch_size=self.batch_size, parallel=self.parallel
)
)
return query_embeddings.tolist()