diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index a6019c129..63a484151 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -18,6 +18,7 @@ from pilot.logs import logger
 def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
     # TODO: vicuna-v1.5 8-bit quantization info is slow
     # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
+    # TODO: support internlm quantization
     model_name = model_params.model_name.lower()
     supported_models = ["llama", "baichuan", "vicuna"]
     return any(m in model_name for m in supported_models)
diff --git a/pilot/utils/model_utils.py b/pilot/utils/model_utils.py
index a7a51ad32..037e3a021 100644
--- a/pilot/utils/model_utils.py
+++ b/pilot/utils/model_utils.py
@@ -2,10 +2,12 @@ import logging
 
 
 def _clear_torch_cache(device="cuda"):
+    try:
+        import torch
+    except ImportError:
+        return
     import gc
 
-    import torch
-
     gc.collect()
     if device != "cpu":
         if torch.has_mps: