fix: num_gpus referenced error for mps + cpu (#407)

close #406
This commit is contained in:
Aries-ckt
2023-08-03 17:56:33 +08:00
committed by GitHub

View File

@@ -115,6 +115,10 @@ class ModelLoader(metaclass=Singleton):
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
device = model_params.device
max_memory = None
# if device is cpu or mps. gpu need to be zero
num_gpus = 0
if device == "cpu":
kwargs = {"torch_dtype": torch.float32}
elif device == "cuda":