gpufix partial

This commit is contained in:
Stephen Gresham 2024-03-21 08:53:36 +11:00
parent 207531d555
commit 16ee058fa2
3 changed files with 15 additions and 9 deletions

View File

@ -41,16 +41,14 @@ class EmbeddingComponent:
# Check if CUDA is available
if torch.cuda.is_available():
# If settings.embedding.gpu is specified, use that GPU index
if hasattr(settings, 'embedding') and hasattr(settings.embedding, 'gpu'):
gpu_index = settings.embedding.gpu
device = torch.device(f"cuda:{gpu_index}")
if hasattr(settings, 'huggingface') and hasattr(settings.huggingface, 'gpu_type'):
device = torch.device(f"{settings.huggingface.gpu_type}:{settings.huggingface.gpu_number}")
else:
# Use the default GPU (index 0)
device = torch.device("cuda:0")
device = torch.device('cuda:0')
else:
# If CUDA is not available, use CPU
device = torch.device("cpu")
print("Embedding Device: ",device)
self.embedding_model = HuggingFaceEmbedding(
model_name=settings.huggingface.embedding_hf_model_name,
cache_folder=str(models_cache_path),

View File

@ -1,4 +1,4 @@
from typing import Literal
from typing import Literal, Optional
from pydantic import BaseModel, Field
@ -145,6 +145,13 @@ class HuggingFaceSettings(BaseModel):
embedding_hf_model_name: str = Field(
description="Name of the HuggingFace model to use for embeddings"
)
gpu_type: Optional[Literal["cuda","cpu"]] = Field(
description="GPU typedevice for embedding, can be 'cuda' or cpu"
)
gpu_number: int = Field(
0,
description="GPU device number for embedding, will be presented to torch like 'cuda:x'"
)
class EmbeddingSettings(BaseModel):

View File

@ -54,10 +54,11 @@ embedding:
# Should be matching the value above in most cases
mode: huggingface
ingest_mode: simple
# gpu: cuda[0] # if you have more than one GPU and you want to select another. defaults to cuda[0], or cpu if cuda not available
huggingface:
embedding_hf_model_name: BAAI/bge-small-en-v1.5
gpu_type: cuda #GPU typedevice for embedding, can be 'cuda', rocm or cpu". defaults to cuda[0], or cpu if cuda not available
gpu_number: 1 #Directly select a device, normally 0 if only a single GPU
vectorstore:
database: qdrant