mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-20 13:54:48 +00:00
community: improve FastEmbedEmbeddings support for ONNX execution provider (e.g. GPU) (#29645)
I changed how GPU support is implemented in `FastEmbedEmbeddings` to make it more consistent with the existing sparse embeddings implementation in `langchain-qdrant`, which lets the user directly provide the list of ONNX execution providers: https://github.com/langchain-ai/langchain/blob/master/libs/partners/qdrant/langchain_qdrant/fastembed_sparse.py#L15 This is a bit less obvious for a user who just wants to enable the GPU, but it makes it possible to work with execution providers other than `CUDAExecutionProvider` and is more future-proof. Sorry for the disturbance, @ccurme. > Nice to see you just moved to `uv`! It is so much nicer to run format/lint/test! No need to manually rerun `poetry install` with all the required extras now.
This commit is contained in:
parent
1bf620222b
commit
08b9eaaa6f
@ -1,6 +1,6 @@
|
|||||||
import importlib
|
import importlib
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
from typing import Any, Dict, List, Literal, Optional, cast
|
from typing import Any, Dict, List, Literal, Optional, Sequence, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
@ -65,11 +65,12 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
Defaults to `None`.
|
Defaults to `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
gpu: bool = False
|
providers: Optional[Sequence[Any]] = None
|
||||||
"""Enable the use of GPU through CUDA. This requires to install `fastembed-gpu`
|
"""List of ONNX execution providers. Use `["CUDAExecutionProvider"]` to enable the
|
||||||
|
use of GPU when generating embeddings. This requires to install `fastembed-gpu`
|
||||||
instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
|
instead of `fastembed`. See https://qdrant.github.io/fastembed/examples/FastEmbed_GPU
|
||||||
for more details.
|
for more details.
|
||||||
Defaults to False.
|
Defaults to `None`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model: Any = None # : :meta private:
|
model: Any = None # : :meta private:
|
||||||
@ -83,8 +84,12 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
max_length = values.get("max_length")
|
max_length = values.get("max_length")
|
||||||
cache_dir = values.get("cache_dir")
|
cache_dir = values.get("cache_dir")
|
||||||
threads = values.get("threads")
|
threads = values.get("threads")
|
||||||
gpu = values.get("gpu")
|
providers = values.get("providers")
|
||||||
pkg_to_install = "fastembed-gpu" if gpu else "fastembed"
|
pkg_to_install = (
|
||||||
|
"fastembed-gpu"
|
||||||
|
if providers and "CUDAExecutionProvider" in providers
|
||||||
|
else "fastembed"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
fastembed = importlib.import_module("fastembed")
|
fastembed = importlib.import_module("fastembed")
|
||||||
@ -106,7 +111,7 @@ class FastEmbedEmbeddings(BaseModel, Embeddings):
|
|||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
threads=threads,
|
threads=threads,
|
||||||
providers=["CUDAExecutionProvider"] if gpu else None,
|
providers=providers,
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user