mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community: add support for using GPUs with FastEmbedEmbeddings (#29627)
- **Description:** add a `gpu: bool = False` field to the `FastEmbedEmbeddings` class which enables using the GPU (through the ONNX CUDA provider) when generating embeddings with any fastembed model. It just requires the user to install a different dependency, and we use a different provider when instantiating `fastembed.TextEmbedding`. - **Issue:** when generating embeddings for a very large number of documents this drastically increases performance (honestly that is a must-have in some situations; you can't just use the CPU, it is way too slow). - **Dependencies:** no direct change to dependencies, but internally users will need to install `fastembed-gpu` instead of `fastembed`. I made all the changes to the init function to properly let the user know which dependency they should install depending on whether they enabled `gpu` or not. See the fastembed docs about GPU support for more details: https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/ I did not add tests because they would require access to a GPU in the testing environment.
This commit is contained in:
parent
0ceda557aa
commit
0ac5536f04
@ -65,6 +65,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
Defaults to `None`.
|
||||
"""
|
||||
|
||||
gpu: bool = False
|
||||
"""Enable the use of GPU through CUDA. This requires to install `fastembed-gpu`
|
||||
instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
|
||||
for more details.
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
model: Any = None # : :meta private:
|
||||
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||
@ -76,19 +83,22 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
max_length = values.get("max_length")
|
||||
cache_dir = values.get("cache_dir")
|
||||
threads = values.get("threads")
|
||||
gpu = values.get("gpu")
|
||||
pkg_to_import = "fastembed-gpu" if gpu else "fastembed"
|
||||
|
||||
try:
|
||||
fastembed = importlib.import_module("fastembed")
|
||||
fastembed = importlib.import_module(pkg_to_import)
|
||||
|
||||
except ModuleNotFoundError:
|
||||
raise ImportError(
|
||||
"Could not import 'fastembed' Python package. "
|
||||
"Please install it with `pip install fastembed`."
|
||||
f"Could not import '{pkg_to_import}' Python package. "
|
||||
f"Please install it with `pip install {pkg_to_import}`."
|
||||
)
|
||||
|
||||
if importlib.metadata.version("fastembed") < MIN_VERSION:
|
||||
if importlib.metadata.version(pkg_to_import) < MIN_VERSION:
|
||||
raise ImportError(
|
||||
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
|
||||
f"FastEmbedEmbeddings requires "
|
||||
f'`pip install -U "{pkg_to_import}>={MIN_VERSION}"`.'
|
||||
)
|
||||
|
||||
values["model"] = fastembed.TextEmbedding(
|
||||
@ -96,6 +106,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
||||
max_length=max_length,
|
||||
cache_dir=cache_dir,
|
||||
threads=threads,
|
||||
providers=["CUDAExecutionProvider"] if gpu else None,
|
||||
)
|
||||
return values
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user