diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 531080314..bd31bae0a 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -9,6 +9,7 @@ from typing import Optional from pilot.model.compression import compress_module from pilot.model.adapter import get_llm_model_adapter from pilot.utils import get_gpu_memory +from pilot.configs.model_config import DEVICE from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations def raise_warning_for_incompatible_cpu_offloading_configuration( @@ -50,7 +51,7 @@ class ModelLoader(metaclass=Singleton): def __init__(self, model_path) -> None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = DEVICE self.model_path = model_path self.kwargs = { "torch_dtype": torch.float16,