diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 7d33d8c6a..c1af72e6e 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -115,6 +115,10 @@ class ModelLoader(metaclass=Singleton): def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams): device = model_params.device max_memory = None + + # If the device is cpu or mps, the GPU count must be zero. + num_gpus = 0 + if device == "cpu": kwargs = {"torch_dtype": torch.float32} elif device == "cuda":