From 896af4e16f8e6ccb1283ac6b963eb21d588861f6 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Thu, 21 Sep 2023 12:20:18 +0800 Subject: [PATCH] chore: fix shutdown error when not install torch --- pilot/model/loader.py | 1 + pilot/utils/model_utils.py | 6 ++++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pilot/model/loader.py b/pilot/model/loader.py index a6019c129..63a484151 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -18,6 +18,7 @@ from pilot.logs import logger def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters): # TODO: vicuna-v1.5 8-bit quantization info is slow # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5 + # TODO: support internlm quantization model_name = model_params.model_name.lower() supported_models = ["llama", "baichuan", "vicuna"] return any(m in model_name for m in supported_models) diff --git a/pilot/utils/model_utils.py b/pilot/utils/model_utils.py index a7a51ad32..037e3a021 100644 --- a/pilot/utils/model_utils.py +++ b/pilot/utils/model_utils.py @@ -2,10 +2,12 @@ import logging def _clear_torch_cache(device="cuda"): + try: + import torch + except ImportError: + return import gc - import torch - gc.collect() if device != "cpu": if torch.has_mps: