fix: num_gpus referenced error for mps + cpu (#407)

close #406
This commit is contained in:
Aries-ckt 2023-08-03 17:56:33 +08:00 committed by GitHub
commit e9dfabe9fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -115,6 +115,10 @@ class ModelLoader(metaclass=Singleton):
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
device = model_params.device
max_memory = None
# if device is cpu or mps. gpu need to be zero
num_gpus = 0
if device == "cpu":
kwargs = {"torch_dtype": torch.float32}
elif device == "cuda":