DB-GPT/pilot/model/loader.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from typing import Optional, Dict
import dataclasses
import torch
from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.compression import compress_module
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.singleton import Singleton
from pilot.utils import get_gpu_memory
from pilot.logs import logger


class ModelType:
    """Type of model"""

    HF = "huggingface"
    LLAMA_CPP = "llama.cpp"
    # TODO: support more model types


@dataclasses.dataclass
class ModelParams:
    device: str
    model_name: str
    model_path: str
    model_type: Optional[str] = ModelType.HF
    num_gpus: Optional[int] = None
    max_gpu_memory: Optional[str] = None
    cpu_offloading: Optional[bool] = False
    load_8bit: Optional[bool] = True
    load_4bit: Optional[bool] = False
    # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float)
    quant_type: Optional[str] = "nf4"
    # Nested quantization is activated through `use_double_quant`
    use_double_quant: Optional[bool] = True
    # "bfloat16", "float16", "float32"
    compute_dtype: Optional[str] = None
    debug: Optional[bool] = False
    trust_remote_code: Optional[bool] = True
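
# Illustrative only, not part of the original module: a minimal sketch of how a
# ModelParams instance might be built by hand (the model name and path below are
# hypothetical placeholders).
#
#     params = ModelParams(
#         device="cuda",
#         model_name="vicuna-13b",
#         model_path="/data/models/vicuna-13b",
#         load_8bit=True,
#     )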


def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)


def _check_quantization(model_params: ModelParams):
    model_name = model_params.model_name.lower()
    has_quantization = model_params.load_8bit or model_params.load_4bit
    if has_quantization:
        if model_params.device != "cuda":
            logger.warn(
                "8-bit and 4-bit quantization are only supported on CUDA devices"
            )
            return False
        elif "chatglm" in model_name:
            if "int4" not in model_name:
                logger.warn(
                    "chatglm and chatglm2 do not support quantization yet, see: https://github.com/huggingface/transformers/issues/25228"
                )
                return False
    return has_quantization
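
# Illustrative only, not part of the original module: expected outcomes of the
# checks above for a few hypothetical configurations.
#
#     _check_quantization(ModelParams("cuda", "vicuna-13b", "/models/vicuna-13b",
#                                     load_8bit=True))   # -> True
#     _check_quantization(ModelParams("cpu", "vicuna-13b", "/models/vicuna-13b",
#                                     load_8bit=True))   # -> False (CUDA only)
#     _check_quantization(ModelParams("cuda", "chatglm2-6b", "/models/chatglm2-6b",
#                                     load_8bit=True))   # -> False (no int4 variant)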


class ModelLoader(metaclass=Singleton):
    """Model loader is a class for loading models.

    Args: model_path

    TODO: multi model support.
    """

    kwargs = {}

    def __init__(self, model_path: str, model_name: str = None) -> None:
        self.device = DEVICE
        self.model_path = model_path
        self.model_name = model_name
        self.kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
        }

    # TODO multi gpu support
    def loader(
        self,
        load_8bit=False,
        load_4bit=False,
        debug=False,
        cpu_offloading=False,
        max_gpu_memory: Optional[str] = None,
    ):
        model_params = ModelParams(
            device=self.device,
            model_path=self.model_path,
            model_name=self.model_name,
            max_gpu_memory=max_gpu_memory,
            cpu_offloading=cpu_offloading,
            load_8bit=load_8bit,
            load_4bit=load_4bit,
            debug=debug,
        )
        logger.info(f"model_params:\n{model_params}")

        llm_adapter = get_llm_model_adapter(model_params.model_path)
        return huggingface_loader(llm_adapter, model_params)
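
# Illustrative only, not part of the original module: a minimal usage sketch,
# assuming a hypothetical local checkpoint directory.
#
#     loader = ModelLoader(model_path="/data/models/vicuna-13b", model_name="vicuna-13b")
#     model, tokenizer = loader.loader(load_8bit=True, max_gpu_memory="16GiB")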


def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
    device = model_params.device
    max_memory = None

    # if device is cpu or mps, the number of gpus must be zero
    num_gpus = 0
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        num_gpus = torch.cuda.device_count()
        available_gpu_memory = get_gpu_memory(num_gpus)
        max_memory = {
            i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)
        }
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            if model_params.max_gpu_memory:
                logger.info(
                    f"Using max_gpu_memory from config: {model_params.max_gpu_memory}"
                )
                max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
                kwargs["max_memory"] = max_memory
            else:
                kwargs["max_memory"] = max_memory
            logger.debug(f"max_memory: {max_memory}")
    elif device == "mps":
        kwargs = {"torch_dtype": torch.float16}
        replace_llama_attn_with_non_inplace_operations()
    else:
        raise ValueError(f"Invalid device: {device}")

    can_quantization = _check_quantization(model_params)
    if can_quantization and (num_gpus > 1 or model_params.load_4bit):
        if _check_multi_gpu_or_4bit_quantization(model_params):
            return load_huggingface_quantization_model(
                llm_adapter, model_params, kwargs, max_memory
            )
        else:
            logger.warn(
                f"Current model {model_params.model_name} does not support quantization"
            )

    # default loader
    model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)

    if model_params.load_8bit and num_gpus == 1 and tokenizer:
        # TODO: merge this code path into `load_huggingface_quantization_model`
        compress_module(model, model_params.device)

    if (
        (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
        or (device == "mps" and tokenizer)
    ):
        try:
            model.to(device)
        except ValueError:
            pass
        except AttributeError:
            pass

    if model_params.debug:
        print(model)

    return model, tokenizer


def load_huggingface_quantization_model(
    llm_adapter: BaseLLMAdaper,
    model_params: ModelParams,
    kwargs: Dict,
    max_memory: Dict[int, str],
):
    try:
        from accelerate import init_empty_weights
        from accelerate.utils import infer_auto_device_map
        import transformers
        from transformers import (
            BitsAndBytesConfig,
            AutoConfig,
            AutoModel,
            AutoModelForCausalLM,
            LlamaForCausalLM,
            AutoModelForSeq2SeqLM,
            LlamaTokenizer,
            AutoTokenizer,
        )
    except ImportError as exc:
        raise ValueError(
            "Could not import the required Python packages. "
            "Please install them with `pip install transformers bitsandbytes accelerate`."
        ) from exc

    if (
        "llama-2" in model_params.model_name.lower()
        and not transformers.__version__ >= "4.31.0"
    ):
        raise ValueError(
            "Llama-2 quantization requires transformers.__version__ >= 4.31.0"
        )

    params = {"low_cpu_mem_usage": True}
    params["device_map"] = "auto"

    torch_dtype = kwargs.get("torch_dtype")

    if model_params.load_4bit:
        compute_dtype = None
        if model_params.compute_dtype and model_params.compute_dtype in [
            "bfloat16",
            "float16",
            "float32",
        ]:
            compute_dtype = getattr(torch, model_params.compute_dtype)

        quantization_config_params = {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_quant_type": model_params.quant_type,
            "bnb_4bit_use_double_quant": model_params.use_double_quant,
        }
        logger.warn(
            "Using the following 4-bit params: " + str(quantization_config_params)
        )
        params["quantization_config"] = BitsAndBytesConfig(**quantization_config_params)
    elif model_params.load_8bit and max_memory:
        params["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
        )
    elif model_params.load_8bit:
        params["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)

    params["torch_dtype"] = torch_dtype if torch_dtype else torch.float16
    params["max_memory"] = max_memory

    if "chatglm" in model_params.model_name.lower():
        LoaderClass = AutoModel
    else:
        config = AutoConfig.from_pretrained(
            model_params.model_path, trust_remote_code=model_params.trust_remote_code
        )
        if config.to_dict().get("is_encoder_decoder", False):
            LoaderClass = AutoModelForSeq2SeqLM
        else:
            LoaderClass = AutoModelForCausalLM

    if model_params.load_8bit and max_memory is not None:
        config = AutoConfig.from_pretrained(
            model_params.model_path, trust_remote_code=model_params.trust_remote_code
        )
        with init_empty_weights():
            model = LoaderClass.from_config(
                config, trust_remote_code=model_params.trust_remote_code
            )
        model.tie_weights()
        params["device_map"] = infer_auto_device_map(
            model,
            dtype=torch.int8,
            max_memory=params["max_memory"].copy(),
            no_split_module_classes=model._no_split_modules,
        )

    try:
        if model_params.trust_remote_code:
            params["trust_remote_code"] = True
        logger.info(f"params: {params}")
        model = LoaderClass.from_pretrained(model_params.model_path, **params)
    except Exception as e:
        logger.error(
            f"Load quantization model failed, error: {str(e)}, params: {params}"
        )
        raise e

    # Loading the tokenizer
    if type(model) is LlamaForCausalLM:
        logger.info(
            "Current model is LlamaForCausalLM; loading tokenizer with LlamaTokenizer"
        )
        tokenizer = LlamaTokenizer.from_pretrained(
            model_params.model_path, clean_up_tokenization_spaces=True
        )

        # Leaving this here until the LLaMA tokenizer gets figured out.
        # For some people this fixes things, for others it causes an error.
        try:
            tokenizer.eos_token_id = 2
            tokenizer.bos_token_id = 1
            tokenizer.pad_token_id = 0
        except Exception as e:
            logger.warn(f"{str(e)}")
    else:
        logger.info(
            "Current model is not LlamaForCausalLM; loading tokenizer with AutoTokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_params.model_path,
            trust_remote_code=model_params.trust_remote_code,
            use_fast=llm_adapter.use_fast_tokenizer(),
        )

    return model, tokenizer
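
# Illustrative only, not part of the original module: a sketch of the 4-bit path,
# assuming a hypothetical Llama-2 checkpoint directory. On a CUDA machine,
# huggingface_loader() would route this call through
# load_huggingface_quantization_model() because load_4bit is set.
#
#     loader = ModelLoader(
#         model_path="/data/models/llama-2-13b-chat-hf",
#         model_name="llama-2-13b-chat-hf",
#     )
#     model, tokenizer = loader.loader(load_4bit=True)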