feat: Support 8-bit and 4-bit quantization for multi-GPU inference

FangYin Cheng 2023-08-02 15:51:57 +08:00
parent e16a5ccfc9
commit d8a4b776d5
8 changed files with 368 additions and 93 deletions

View File

@@ -1,25 +1,48 @@
ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"

FROM ${BASE_IMAGE}
ARG BASE_IMAGE

RUN apt-get update && apt-get install -y git python3 pip wget \
    && apt-get clean

ARG BUILD_LOCAL_CODE="false"
ARG LANGUAGE="en"
ARG PIP_INDEX_URL="https://pypi.org/simple"
ENV PIP_INDEX_URL=$PIP_INDEX_URL

# COPY only requirements.txt first to leverage Docker cache
COPY ./requirements.txt /tmp/requirements.txt

WORKDIR /app

RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
    && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \
    # if not building from local code, clone the latest code from git into /app. TODO: download by version, e.g. https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
    then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \
    && cp /app/requirements.txt /tmp/requirements.txt; \
    fi;) \
    && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
    && rm /tmp/requirements.txt

RUN (if [ "${LANGUAGE}" = "zh" ]; \
    # language is zh, download zh_core_web_sm from github
    then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
        && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
        && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl; \
    # not zh, download via spacy directly
    else python3 -m spacy download zh_core_web_sm; \
    fi;) \
    && rm -rf `pip3 cache dir`

ARG BUILD_LOCAL_CODE="false"
# COPY the rest of the app
COPY . /tmp/app

# TODO: Need to find a better way to determine whether to build the docker image from local code.
RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
    then mv /tmp/app / && rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
    else rm -rf /tmp/app; \
    fi;)

EXPOSE 5000

View File

@@ -4,5 +4,72 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd)

BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
IMAGE_NAME="db-gpt"
# zh: https://pypi.tuna.tsinghua.edu.cn/simple
PIP_INDEX_URL="https://pypi.org/simple"
# en or zh
LANGUAGE="en"
BUILD_LOCAL_CODE="false"

usage () {
    echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04] [--image-name db-gpt]"
    echo "  [-b|--base-image base image name] Base image name"
    echo "  [-n|--image-name image name] Current image name, default: db-gpt"
    echo "  [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
    echo "  [--language en or zh] Your language, default: en"
    echo "  [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
    echo "  [-h|--help] Usage message"
}

while [[ $# -gt 0 ]]; do
    key="$1"
    case $key in
        -b|--base-image)
            BASE_IMAGE="$2"
            shift # past argument
            shift # past value
            ;;
        -n|--image-name)
            IMAGE_NAME="$2"
            shift # past argument
            shift # past value
            ;;
        -i|--pip-index-url)
            PIP_INDEX_URL="$2"
            shift
            shift
            ;;
        --language)
            LANGUAGE="$2"
            shift
            shift
            ;;
        --build-local-code)
            BUILD_LOCAL_CODE="$2"
            shift
            shift
            ;;
        -h|--help)
            help="true"
            shift
            ;;
        *)
            usage
            exit 1
            ;;
    esac
done

if [[ $help ]]; then
    usage
    exit 0
fi

docker build \
    --build-arg BASE_IMAGE=$BASE_IMAGE \
    --build-arg PIP_INDEX_URL=$PIP_INDEX_URL \
    --build-arg LANGUAGE=$LANGUAGE \
    --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \
    -f Dockerfile \
    -t $IMAGE_NAME $WORK_DIR/../../

View File

@@ -4,6 +4,11 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd)

bash $WORK_DIR/base/build_image.sh "$@"
if [ 0 -ne $? ]; then
    echo "Error: build base image failed"
    exit 1
fi

bash $WORK_DIR/allinone/build_image.sh

View File

@@ -247,7 +247,7 @@ def remove_color_codes(s: str) -> str:
    return ansi_escape.sub("", s)


logger: Logger = Logger()


def print_assistant_thoughts(

View File

@@ -28,6 +28,9 @@ class BaseLLMAdaper:
    """The base class for the multiple models in our project.
    We will support those models whose performance resembles ChatGPT."""

    def use_fast_tokenizer(self) -> bool:
        return False

    def match(self, model_path: str):
        return True
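For illustration only (not part of this commit), a concrete adapter subclass could opt back into the fast tokenizer by overriding the new hook; the class name and match rule below are hypothetical.

# Hypothetical adapter, not part of this commit.
class FastTokenizerAdapter(BaseLLMAdaper):
    def use_fast_tokenizer(self) -> bool:
        # opt into the HF fast tokenizer for this model family
        return True

    def match(self, model_path: str):
        # illustrative match rule on the model path
        return "my-model" in model_path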

View File

@@ -1,44 +1,70 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from typing import Optional, Dict
import dataclasses

import torch

from pilot.configs.model_config import DEVICE
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.compression import compress_module
from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
from pilot.singleton import Singleton
from pilot.utils import get_gpu_memory
from pilot.logs import logger


class ModelType:
    """Type of model"""

    HF = "huggingface"
    LLAMA_CPP = "llama.cpp"
    # TODO: support more model types


@dataclasses.dataclass
class ModelParams:
    device: str
    model_name: str
    model_path: str
    model_type: Optional[str] = ModelType.HF
    num_gpus: Optional[int] = None
    max_gpu_memory: Optional[str] = None
    cpu_offloading: Optional[bool] = False
    load_8bit: Optional[bool] = True
    load_4bit: Optional[bool] = False
    # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float)
    quant_type: Optional[str] = "nf4"
    # Nested quantization is activated through `use_double_quant`
    use_double_quant: Optional[bool] = True
    # "bfloat16", "float16", "float32"
    compute_dtype: Optional[str] = None
    debug: Optional[bool] = False
    trust_remote_code: Optional[bool] = True


def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)


def _check_quantization(model_params: ModelParams):
    model_name = model_params.model_name.lower()
    has_quantization = any([model_params.load_8bit or model_params.load_4bit])
    if has_quantization:
        if model_params.device != "cuda":
            logger.warn(
                "8-bit and 4-bit quantization are only supported on CUDA devices"
            )
            return False
        elif "chatglm" in model_name:
            if "int4" not in model_name:
                logger.warn(
                    "chatglm and chatglm2 do not support quantization yet, see: https://github.com/huggingface/transformers/issues/25228"
                )
                return False
    return has_quantization
class ModelLoader(metaclass=Singleton):
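Illustration (not part of the commit) of how the new ModelParams and the two helper checks fit together; the model path below is a placeholder.

# Sketch, not part of this commit; the model path is hypothetical.
params = ModelParams(
    device="cuda",
    model_name="vicuna-13b",
    model_path="/data/models/vicuna-13b",
    load_8bit=False,
    load_4bit=True,
    quant_type="nf4",
    use_double_quant=True,
)
assert _check_quantization(params)                    # cuda + load_4bit -> quantization allowed
assert _check_multi_gpu_or_4bit_quantization(params)  # "vicuna" is in the supported list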
@@ -51,9 +77,10 @@ class ModelLoader(metaclass=Singleton):
    kwargs = {}

    def __init__(self, model_path: str, model_name: str = None) -> None:
        self.device = DEVICE
        self.model_path = model_path
        self.model_name = model_name
        self.kwargs = {
            "torch_dtype": torch.float16,
            "device_map": "auto",
@@ -64,64 +91,213 @@ class ModelLoader(metaclass=Singleton):
        self,
        num_gpus,
        load_8bit=False,
        load_4bit=False,
        debug=False,
        cpu_offloading=False,
        max_gpu_memory: Optional[str] = None,
    ):
        model_params = ModelParams(
            device=self.device,
            model_path=self.model_path,
            model_name=self.model_name,
            num_gpus=num_gpus,
            max_gpu_memory=max_gpu_memory,
            cpu_offloading=cpu_offloading,
            load_8bit=load_8bit,
            load_4bit=load_4bit,
            debug=debug,
        )

        llm_adapter = get_llm_model_adapter(model_params.model_path)
        return huggingface_loader(llm_adapter, model_params)


def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
    device = model_params.device
    max_memory = None
    if device == "cpu":
        kwargs = {"torch_dtype": torch.float32}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        num_gpus = torch.cuda.device_count()
        available_gpu_memory = get_gpu_memory(num_gpus)
        max_memory = {
            i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)
        }
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            kwargs["max_memory"] = max_memory
        elif model_params.max_gpu_memory:
            logger.info(
                f"Using max_gpu_memory from config: {model_params.max_gpu_memory}"
            )
            max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
            kwargs["max_memory"] = max_memory
        logger.debug(f"max_memory: {max_memory}")
    elif device == "mps":
        kwargs = {"torch_dtype": torch.float16}
        replace_llama_attn_with_non_inplace_operations()
    else:
        raise ValueError(f"Invalid device: {device}")

    can_quantization = _check_quantization(model_params)

    if can_quantization and (num_gpus > 1 or model_params.load_4bit):
        if _check_multi_gpu_or_4bit_quantization(model_params):
            return load_huggingface_quantization_model(
                llm_adapter, model_params, kwargs, max_memory
            )
        else:
            logger.warn(
                f"Current model {model_params.model_name} does not support quantization"
            )

    # default loader
    model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)

    if model_params.load_8bit and num_gpus == 1 and tokenizer:
        # TODO: merge this code path into `load_huggingface_quantization_model`
        compress_module(model, model_params.device)

    if (
        (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
        or device == "mps"
        and tokenizer
    ):
        try:
            model.to(device)
        except ValueError:
            pass
        except AttributeError:
            pass
    if model_params.debug:
        print(model)
    return model, tokenizer
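A hypothetical end-to-end call of the new loader path (not part of the commit); the import path is assumed from the project layout, and the model path and memory cap are placeholders.

# Hypothetical usage sketch, not part of this commit.
from pilot.model.loader import ModelLoader

ml = ModelLoader(model_path="/data/models/vicuna-13b", model_name="vicuna-13b")
model, tokenizer = ml.loader(
    num_gpus=2,              # on CUDA the real count is re-read via torch.cuda.device_count()
    load_8bit=False,
    load_4bit=True,          # routes into load_huggingface_quantization_model for supported models
    max_gpu_memory="20GiB",  # only applied in the single-GPU CUDA branch
)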
def load_huggingface_quantization_model(
    llm_adapter: BaseLLMAdaper,
    model_params: ModelParams,
    kwargs: Dict,
    max_memory: Dict[int, str],
):
    try:
        from accelerate import init_empty_weights
        from accelerate.utils import infer_auto_device_map
        import transformers
        from transformers import (
            BitsAndBytesConfig,
            AutoConfig,
            AutoModel,
            AutoModelForCausalLM,
            LlamaForCausalLM,
            AutoModelForSeq2SeqLM,
            LlamaTokenizer,
            AutoTokenizer,
        )
    except ImportError as exc:
        raise ValueError(
            "Could not import the required Python packages. "
            "Please install them with `pip install transformers` "
            "`pip install bitsandbytes` `pip install accelerate`."
        ) from exc

    if (
        "llama-2" in model_params.model_name.lower()
        and not transformers.__version__ >= "4.31.0"
    ):
        raise ValueError(
            "Llama-2 quantization requires transformers.__version__>=4.31.0"
        )
    params = {"low_cpu_mem_usage": True}
    params["low_cpu_mem_usage"] = True
    params["device_map"] = "auto"

    torch_dtype = kwargs.get("torch_dtype")

    if model_params.load_4bit:
        compute_dtype = None
        if model_params.compute_dtype and model_params.compute_dtype in [
            "bfloat16",
            "float16",
            "float32",
        ]:
            compute_dtype = eval("torch.{}".format(model_params.compute_dtype))

        quantization_config_params = {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": compute_dtype,
            "bnb_4bit_quant_type": model_params.quant_type,
            "bnb_4bit_use_double_quant": model_params.use_double_quant,
        }
        logger.warn(
            "Using the following 4-bit params: " + str(quantization_config_params)
        )
        params["quantization_config"] = BitsAndBytesConfig(**quantization_config_params)
    elif model_params.load_8bit and max_memory:
        params["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
        )
    elif model_params.load_8bit:
        params["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
    params["torch_dtype"] = torch_dtype if torch_dtype else torch.float16
    params["max_memory"] = max_memory

    if "chatglm" in model_params.model_name.lower():
        LoaderClass = AutoModel
    else:
        config = AutoConfig.from_pretrained(
            model_params.model_path, trust_remote_code=model_params.trust_remote_code
        )
        if config.to_dict().get("is_encoder_decoder", False):
            LoaderClass = AutoModelForSeq2SeqLM
        else:
            LoaderClass = AutoModelForCausalLM

    if model_params.load_8bit and max_memory is not None:
        config = AutoConfig.from_pretrained(
            model_params.model_path, trust_remote_code=model_params.trust_remote_code
        )
        with init_empty_weights():
            model = LoaderClass.from_config(
                config, trust_remote_code=model_params.trust_remote_code
            )
        model.tie_weights()
        params["device_map"] = infer_auto_device_map(
            model,
            dtype=torch.int8,
            max_memory=params["max_memory"].copy(),
            no_split_module_classes=model._no_split_modules,
        )
    try:
        if model_params.trust_remote_code:
            params["trust_remote_code"] = True
        logger.info(f"params: {params}")
        model = LoaderClass.from_pretrained(model_params.model_path, **params)
    except Exception as e:
        logger.error(
            f"Load quantization model failed, error: {str(e)}, params: {params}"
        )
        raise e

    # Loading the tokenizer
    if type(model) is LlamaForCausalLM:
        tokenizer = LlamaTokenizer.from_pretrained(
            model_params.model_path, clean_up_tokenization_spaces=True
        )
        # Leaving this here until the LLaMA tokenizer gets figured out.
        # For some people this fixes things, for others it causes an error.
        try:
            tokenizer.eos_token_id = 2
            tokenizer.bos_token_id = 1
            tokenizer.pad_token_id = 0
        except Exception as e:
            logger.warn(f"{str(e)}")
    else:
        tokenizer = AutoTokenizer.from_pretrained(
            model_params.model_path,
            trust_remote_code=model_params.trust_remote_code,
            use_fast=llm_adapter.use_fast_tokenizer(),
        )
    return model, tokenizer
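For readers unfamiliar with bitsandbytes, the 4-bit knobs carried by ModelParams map onto a plain transformers call roughly as below; this is a standalone sketch, not the commit's code, and the model path is a placeholder.

# Standalone sketch of the equivalent 4-bit NF4 setup with plain transformers.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # ModelParams.load_4bit
    bnb_4bit_quant_type="nf4",              # ModelParams.quant_type ("fp4" or "nf4")
    bnb_4bit_use_double_quant=True,         # ModelParams.use_double_quant
    bnb_4bit_compute_dtype=torch.bfloat16,  # ModelParams.compute_dtype
)
model = AutoModelForCausalLM.from_pretrained(
    "/data/models/vicuna-13b",
    device_map="auto",                      # shards layers across all visible GPUs
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained("/data/models/vicuna-13b")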

View File

@@ -37,7 +37,7 @@ class ModelWorker:
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device
        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
        self.ml = ModelLoader(model_path=model_path, model_name=self.model_name)
        self.model, self.tokenizer = self.ml.loader(
            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
        )

View File

@@ -1,10 +1,8 @@
torch==2.0.0
aiohttp==3.8.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==22.2.0
cchardet==2.1.7
chardet==5.1.0
contourpy==1.0.7
@@ -27,7 +25,7 @@ python-dateutil==2.8.2
pyyaml==6.0
tokenizers==0.13.2
tqdm==4.64.1
transformers>=4.31.0
transformers_stream_generator
timm==0.6.13
spacy==3.5.3
@@ -48,6 +46,9 @@ gradio-client==0.0.8
wandb
llama-index==0.5.27
bitsandbytes
accelerate>=0.20.3
unstructured==0.6.3
grpcio==1.47.5
gpt4all==0.3.0
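The loader's Llama-2 guard compares raw version strings; a more robust check of the new dependency floors could look like the sketch below (not part of the commit).

# Sketch, not part of this commit: verify the new dependency floors with
# packaging instead of comparing raw version strings.
from packaging import version
import accelerate
import bitsandbytes
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.31.0")
assert version.parse(accelerate.__version__) >= version.parse("0.20.3")
print("bitsandbytes", bitsandbytes.__version__)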