mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-25 11:29:29 +00:00
chore: fix shutdown error when not install torch
This commit is contained in:
parent
c830598c9e
commit
896af4e16f
@ -18,6 +18,7 @@ from pilot.logs import logger
|
|||||||
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
|
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParameters):
|
||||||
# TODO: vicuna-v1.5 8-bit quantization info is slow
|
# TODO: vicuna-v1.5 8-bit quantization info is slow
|
||||||
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
|
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
|
||||||
|
# TODO: support internlm quantization
|
||||||
model_name = model_params.model_name.lower()
|
model_name = model_params.model_name.lower()
|
||||||
supported_models = ["llama", "baichuan", "vicuna"]
|
supported_models = ["llama", "baichuan", "vicuna"]
|
||||||
return any(m in model_name for m in supported_models)
|
return any(m in model_name for m in supported_models)
|
||||||
|
@ -2,10 +2,12 @@ import logging
|
|||||||
|
|
||||||
|
|
||||||
def _clear_torch_cache(device="cuda"):
|
def _clear_torch_cache(device="cuda"):
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ImportError:
|
||||||
|
return
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
if device != "cpu":
|
if device != "cpu":
|
||||||
if torch.has_mps:
|
if torch.has_mps:
|
||||||
|
Loading…
Reference in New Issue
Block a user