From d67a6a642abf89bbf25499fefab8ec4e893b1852 Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 3 Aug 2023 16:52:39 +0800 Subject: [PATCH] fix: num_gpus referenced error for mps + cpu --- pilot/model/loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 7d33d8c6a..9631d7a0a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -115,6 +115,10 @@ class ModelLoader(metaclass=Singleton): def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams): device = model_params.device max_memory = None + + # if device is cpu or mps. gpu need to be zero + num_gpus = 0 + if device == "cpu": kwargs = {"torch_dtype": torch.float32} elif device == "cuda":