From 60ecde5892e8ecc25a85a9e9bd0f46e5906f8397 Mon Sep 17 00:00:00 2001 From: yihong0618 Date: Wed, 24 May 2023 12:33:41 +0800 Subject: [PATCH] fix: can not answer on mac m1-> mps device --- pilot/model/loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 531080314..bd31bae0a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -9,6 +9,7 @@ from typing import Optional from pilot.model.compression import compress_module from pilot.model.adapter import get_llm_model_adapter from pilot.utils import get_gpu_memory +from pilot.configs.model_config import DEVICE from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations def raise_warning_for_incompatible_cpu_offloading_configuration( @@ -50,7 +51,7 @@ class ModelLoader(metaclass=Singleton): def __init__(self, model_path) -> None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = DEVICE self.model_path = model_path self.kwargs = { "torch_dtype": torch.float16,