From 743863d52baab8c369871e4ea6f4833a736c34bd Mon Sep 17 00:00:00 2001 From: csunny Date: Thu, 3 Aug 2023 16:54:58 +0800 Subject: [PATCH] fix: set num_gpus reference for mps + cpu --- pilot/model/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 9631d7a0a..c1af72e6e 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -117,8 +117,8 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams): max_memory = None # if device is cpu or mps. gpu need to be zero - num_gpus = 0 - + num_gpus = 0 + if device == "cpu": kwargs = {"torch_dtype": torch.float32} elif device == "cuda":