mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-21 14:18:52 +00:00
community: add support for using GPUs with FastEmbedEmbeddings (#29627)
- **Description:** add a `gpu: bool = False` field to the `FastEmbedEmbeddings` class which enables to use GPU (through ONNX CUDA provider) when generating embeddings with any fastembed model. It just requires the user to install a different dependency and we use a different provider when instantiating `fastembed.TextEmbedding` - **Issue:** when generating embeddings for a really large amount of documents this drastically increase performance (honestly that is a must have in some situations, you can't just use CPU it is way too slow) - **Dependencies:** no direct change to dependencies, but internally the users will need to install `fastembed-gpu` instead of `fastembed`, I made all the changes to the init function to properly let the user know which dependency they should install depending on if they enabled `gpu` or not cf. fastembed docs about GPU for more details: https://qdrant.github.io/fastembed/examples/FastEmbed_GPU/ I did not added test because it would require access to a GPU in the testing environment
This commit is contained in:
parent
0ceda557aa
commit
0ac5536f04
@ -65,6 +65,13 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
Defaults to `None`.
|
Defaults to `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
gpu: bool = False
|
||||||
|
"""Enable the use of GPU through CUDA. This requires to install `fastembed-gpu`
|
||||||
|
instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
|
||||||
|
for more details.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
model: Any = None # : :meta private:
|
model: Any = None # : :meta private:
|
||||||
|
|
||||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||||
@ -76,19 +83,22 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
max_length = values.get("max_length")
|
max_length = values.get("max_length")
|
||||||
cache_dir = values.get("cache_dir")
|
cache_dir = values.get("cache_dir")
|
||||||
threads = values.get("threads")
|
threads = values.get("threads")
|
||||||
|
gpu = values.get("gpu")
|
||||||
|
pkg_to_import = "fastembed-gpu" if gpu else "fastembed"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fastembed = importlib.import_module("fastembed")
|
fastembed = importlib.import_module(pkg_to_import)
|
||||||
|
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
"Could not import 'fastembed' Python package. "
|
f"Could not import '{pkg_to_import}' Python package. "
|
||||||
"Please install it with `pip install fastembed`."
|
f"Please install it with `pip install {pkg_to_import}`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if importlib.metadata.version("fastembed") < MIN_VERSION:
|
if importlib.metadata.version(pkg_to_import) < MIN_VERSION:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
'FastEmbedEmbeddings requires `pip install -U "fastembed>=0.2.0"`.'
|
f"FastEmbedEmbeddings requires "
|
||||||
|
f'`pip install -U "{pkg_to_import}>={MIN_VERSION}"`.'
|
||||||
)
|
)
|
||||||
|
|
||||||
values["model"] = fastembed.TextEmbedding(
|
values["model"] = fastembed.TextEmbedding(
|
||||||
@ -96,6 +106,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
threads=threads,
|
threads=threads,
|
||||||
|
providers=["CUDAExecutionProvider"] if gpu else None,
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user