From d8a4b776d512c2bfbd38870caa29603d6619da32 Mon Sep 17 00:00:00 2001
From: FangYin Cheng
Date: Wed, 2 Aug 2023 15:51:57 +0800
Subject: [PATCH] feat: Support 8-bit quantization and 4-bit quantization for multi-gpu inference

---
 docker/base/Dockerfile     |  51 ++++--
 docker/base/build_image.sh |  69 +++++++-
 docker/build_all_images.sh |   7 +-
 pilot/logs.py              |   2 +-
 pilot/model/adapter.py     |   3 +
 pilot/model/loader.py      | 320 ++++++++++++++++++++++++++++---------
 pilot/server/llmserver.py  |   2 +-
 requirements.txt           |   7 +-
 8 files changed, 368 insertions(+), 93 deletions(-)

diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile
index 1b9274fea..5144c1348 100644
--- a/docker/base/Dockerfile
+++ b/docker/base/Dockerfile
@@ -1,25 +1,48 @@
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
+
+FROM ${BASE_IMAGE}
+ARG BASE_IMAGE
 
 RUN apt-get update && apt-get install -y git python3 pip wget \
     && apt-get clean
 
-# download code from githu: https://github.com/csunny/DB-GPT
-# ENV DBGPT_VERSION="v0.3.3"
-# RUN wget https://github.com/csunny/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
+ARG BUILD_LOCAL_CODE="false"
+ARG LANGUAGE="en"
+ARG PIP_INDEX_URL="https://pypi.org/simple"
+ENV PIP_INDEX_URL=$PIP_INDEX_URL
 
-# clone latest code, and rename to /app
-RUN git clone https://github.com/csunny/DB-GPT.git /app
+# COPY only requirements.txt first to leverage the Docker build cache
+COPY ./requirements.txt /tmp/requirements.txt
 
 WORKDIR /app
-RUN pip3 install --upgrade pip \
-    && pip3 install --no-cache-dir -r requirements.txt \
-    && pip3 install seaborn mpld3 \
-    && wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
-    && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
-    && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
-    && rm -rf `pip3 cache dir`
+RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
+    && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \
+        # not building from local code: clone the latest code from git into /app. TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
+        then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \
+        && cp /app/requirements.txt /tmp/requirements.txt; \
+        fi;) \
+    && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
+    && rm /tmp/requirements.txt
 
-# RUN python3 -m spacy download zh_core_web_sm
+RUN (if [ "${LANGUAGE}" = "zh" ]; \
+    # language is zh: download the zh_core_web_sm wheel from GitHub
+    then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
+    && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
+    && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl; \
+    # not zh: download the model directly via spacy
+    else python3 -m spacy download zh_core_web_sm; \
+    fi;) \
+    && rm -rf `pip3 cache dir`
+
+ARG BUILD_LOCAL_CODE="false"
+# COPY the rest of the app
+COPY . /tmp/app
+
+# TODO: Need to find a better way to determine whether to build the docker image with local code.
+RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \ + then mv /tmp/app / && rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \ + else rm -rf /tmp/app; \ + fi;) EXPOSE 5000 \ No newline at end of file diff --git a/docker/base/build_image.sh b/docker/base/build_image.sh index 7c4be28bf..3c9af8c8e 100755 --- a/docker/base/build_image.sh +++ b/docker/base/build_image.sh @@ -4,5 +4,72 @@ SCRIPT_LOCATION=$0 cd "$(dirname "$SCRIPT_LOCATION")" WORK_DIR=$(pwd) +BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04" IMAGE_NAME="db-gpt" -docker build -f Dockerfile -t $IMAGE_NAME $WORK_DIR/../../ \ No newline at end of file +# zh: https://pypi.tuna.tsinghua.edu.cn/simple +PIP_INDEX_URL="https://pypi.org/simple" +# en or zh +LANGUAGE="en" +BUILD_LOCAL_CODE="false" + +usage () { + echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04] [--image-name db-gpt]" + echo " [-b|--base-image base image name] Base image name" + echo " [-n|--image-name image name] Current image name, default: db-gpt" + echo " [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple" + echo " [--language en or zh] You language, default: en" + echo " [--build-local-code true or false] Whether to use the local project code to package the image, default: false" + echo " [-h|--help] Usage message" +} + +while [[ $# -gt 0 ]]; do + key="$1" + case $key in + -b|--base-image) + BASE_IMAGE="$2" + shift # past argument + shift # past value + ;; + -n|--image-name) + IMAGE_NAME="$2" + shift # past argument + shift # past value + ;; + -i|--pip-index-url) + PIP_INDEX="$2" + shift + shift + ;; + --language) + LANGUAGE="$2" + shift + shift + ;; + --build-local-code) + BUILD_LOCAL_CODE="$2" + shift + shift + ;; + -h|--help) + help="true" + shift + ;; + *) + usage + exit 1 + ;; + esac +done + +if [[ $help ]]; then + usage + exit 0 +fi + +docker build \ + --build-arg BASE_IMAGE=$BASE_IMAGE \ + --build-arg PIP_INDEX_URL=$PIP_INDEX_URL \ + --build-arg LANGUAGE=$LANGUAGE \ + --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \ + -f Dockerfile \ + -t $IMAGE_NAME $WORK_DIR/../../ diff --git a/docker/build_all_images.sh b/docker/build_all_images.sh index a49015c26..ec307f1fe 100755 --- a/docker/build_all_images.sh +++ b/docker/build_all_images.sh @@ -4,6 +4,11 @@ SCRIPT_LOCATION=$0 cd "$(dirname "$SCRIPT_LOCATION")" WORK_DIR=$(pwd) -bash $WORK_DIR/base/build_image.sh +bash $WORK_DIR/base/build_image.sh "$@" + +if [ 0 -ne $? ]; then + ehco "Error: build base image failed" + exit 1 +fi bash $WORK_DIR/allinone/build_image.sh \ No newline at end of file diff --git a/pilot/logs.py b/pilot/logs.py index 52d25b5fd..cf5d5603f 100644 --- a/pilot/logs.py +++ b/pilot/logs.py @@ -247,7 +247,7 @@ def remove_color_codes(s: str) -> str: return ansi_escape.sub("", s) -logger = Logger() +logger: Logger = Logger() def print_assistant_thoughts( diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 900d51d4a..26cbe9371 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -28,6 +28,9 @@ class BaseLLMAdaper: """The Base class for multi model, in our project. 
    We will support those model, which performance resemble ChatGPT"""
 
+    def use_fast_tokenizer(self) -> bool:
+        return False
+
     def match(self, model_path: str):
         return True
 
diff --git a/pilot/model/loader.py b/pilot/model/loader.py
index 6acbc9234..d56ef6a47 100644
--- a/pilot/model/loader.py
+++ b/pilot/model/loader.py
@@ -1,44 +1,70 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import sys
-import warnings
-from typing import Optional
-
+from typing import Optional, Dict
+import dataclasses
 import torch
 
 from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
 from pilot.model.compression import compress_module
 from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
 from pilot.singleton import Singleton
 from pilot.utils import get_gpu_memory
+from pilot.logs import logger
 
 
-def raise_warning_for_incompatible_cpu_offloading_configuration(
-    device: str, load_8bit: bool, cpu_offloading: bool
-):
-    if cpu_offloading:
-        if not load_8bit:
-            warnings.warn(
-                "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
-                "Use '--load-8bit' to enable 8-bit-quantization\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-        if not "linux" in sys.platform:
-            warnings.warn(
-                "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-        if device != "cuda":
-            warnings.warn(
-                "CPU-offloading is only enabled when using CUDA-devices\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-    return cpu_offloading
+class ModelType:
+    """Type of model"""
+
+    HF = "huggingface"
+    LLAMA_CPP = "llama.cpp"
+    # TODO: support more model types
+
+
+@dataclasses.dataclass
+class ModelParams:
+    device: str
+    model_name: str
+    model_path: str
+    model_type: Optional[str] = ModelType.HF
+    num_gpus: Optional[int] = None
+    max_gpu_memory: Optional[str] = None
+    cpu_offloading: Optional[bool] = False
+    load_8bit: Optional[bool] = True
+    load_4bit: Optional[bool] = False
+    # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float)
+    quant_type: Optional[str] = "nf4"
+    # Nested quantization is activated through `use_double_quant`
+    use_double_quant: Optional[bool] = True
+    # "bfloat16", "float16", "float32"
+    compute_dtype: Optional[str] = None
+    debug: Optional[bool] = False
+    trust_remote_code: Optional[bool] = True
+
+
+def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
+    model_name = model_params.model_name.lower()
+    supported_models = ["llama", "baichuan", "vicuna"]
+    return any(m in model_name for m in supported_models)
+
+
+def _check_quantization(model_params: ModelParams):
+    model_name = model_params.model_name.lower()
+    has_quantization = model_params.load_8bit or model_params.load_4bit
+    if has_quantization:
+        if model_params.device != "cuda":
+            logger.warn(
+                "8-bit quantization and 4-bit quantization are only supported on CUDA devices"
+            )
+            return False
+        elif "chatglm" in model_name:
+            if "int4" not in model_name:
+                logger.warn(
+                    "chatglm and chatglm2 do not support quantization yet, see: https://github.com/huggingface/transformers/issues/25228"
+                )
+            return False
+    return has_quantization
 
 
 class ModelLoader(metaclass=Singleton):
@@ -51,9 +77,10 @@ class ModelLoader(metaclass=Singleton):
 
     kwargs = {}
 
-    def __init__(self, model_path) -> None:
+    def __init__(self, model_path: str, model_name: str = None) -> None:
         self.device = DEVICE
         self.model_path = model_path
+        self.model_name = model_name
         self.kwargs = {
             "torch_dtype": torch.float16,
             "device_map": "auto",
@@ -64,64 +91,213 @@
         self,
         num_gpus,
         load_8bit=False,
+        load_4bit=False,
         debug=False,
         cpu_offloading=False,
         max_gpu_memory: Optional[str] = None,
     ):
-        if self.device == "cpu":
-            kwargs = {"torch_dtype": torch.float32}
-
-        elif self.device == "cuda":
-            kwargs = {"torch_dtype": torch.float16}
-            num_gpus = torch.cuda.device_count()
-
-            if num_gpus != 1:
-                kwargs["device_map"] = "auto"
-                # if max_gpu_memory is None:
-                #     kwargs["device_map"] = "sequential"
-
-                available_gpu_memory = get_gpu_memory(num_gpus)
-                kwargs["max_memory"] = {
-                    i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
-                    for i in range(num_gpus)
-                }
-
-            else:
-                kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
-
-        elif self.device == "mps":
-            kwargs = kwargs = {"torch_dtype": torch.float16}
-            replace_llama_attn_with_non_inplace_operations()
-        else:
-            raise ValueError(f"Invalid device: {self.device}")
-
-        # TODO when cpu loading, need use quantization config
-
-        llm_adapter = get_llm_model_adapter(self.model_path)
-        model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
-
-        if load_8bit and tokenizer:
-            if num_gpus != 1:
-                warnings.warn(
-                    "8-bit quantization is not supported for multi-gpu inference"
-                )
-            else:
-                compress_module(model, self.device)
-
-        if (
-            (self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
-            or self.device == "mps"
-            and tokenizer
-        ):
-            # 4-bit not support this
-            try:
-                model.to(self.device)
-            except ValueError:
-                pass
-            except AttributeError:
-                pass
-
-        if debug:
-            print(model)
-
-        return model, tokenizer
+        model_params = ModelParams(
+            device=self.device,
+            model_path=self.model_path,
+            model_name=self.model_name,
+            num_gpus=num_gpus,
+            max_gpu_memory=max_gpu_memory,
+            cpu_offloading=cpu_offloading,
+            load_8bit=load_8bit,
+            load_4bit=load_4bit,
+            debug=debug,
+        )
+
+        llm_adapter = get_llm_model_adapter(model_params.model_path)
+        return huggingface_loader(llm_adapter, model_params)
+
+
+def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
+    device = model_params.device
+    max_memory = None
+    # assume a single device unless CUDA tells us otherwise
+    num_gpus = 1
+    if device == "cpu":
+        kwargs = {"torch_dtype": torch.float32}
+    elif device == "cuda":
+        kwargs = {"torch_dtype": torch.float16}
+        num_gpus = torch.cuda.device_count()
+        available_gpu_memory = get_gpu_memory(num_gpus)
+        max_memory = {
+            i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)
+        }
+        if num_gpus != 1:
+            kwargs["device_map"] = "auto"
+            kwargs["max_memory"] = max_memory
+        elif model_params.max_gpu_memory:
+            logger.info(
+                f"Using max_gpu_memory from config: {model_params.max_gpu_memory}"
+            )
+            max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
+            kwargs["max_memory"] = max_memory
+        logger.debug(f"max_memory: {max_memory}")
+    elif device == "mps":
+        kwargs = {"torch_dtype": torch.float16}
+        replace_llama_attn_with_non_inplace_operations()
+    else:
+        raise ValueError(f"Invalid device: {device}")
+
+    can_quantization = _check_quantization(model_params)
+
+    if can_quantization and (num_gpus > 1 or model_params.load_4bit):
+        if _check_multi_gpu_or_4bit_quantization(model_params):
+            return load_huggingface_quantization_model(
+                llm_adapter, model_params, kwargs, max_memory
+            )
+        else:
+            logger.warn(
+                f"Current model {model_params.model_name} does not support quantization"
+            )
+
+    # default loader
+    model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)
+
+    if model_params.load_8bit and num_gpus == 1 and tokenizer:
+        # TODO: merge this code into `load_huggingface_quantization_model`
+        compress_module(model, model_params.device)
+
+    if (
+        (device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
+        or device == "mps"
+        and tokenizer
+    ):
+        try:
+            model.to(device)
+        except ValueError:
+            pass
+        except AttributeError:
+            pass
+    if model_params.debug:
+        print(model)
+    return model, tokenizer
+
+
+def load_huggingface_quantization_model(
+    llm_adapter: BaseLLMAdaper,
+    model_params: ModelParams,
+    kwargs: Dict,
+    max_memory: Dict[int, str],
+):
+    try:
+        from accelerate import init_empty_weights
+        from accelerate.utils import infer_auto_device_map
+        import transformers
+        from transformers import (
+            BitsAndBytesConfig,
+            AutoConfig,
+            AutoModel,
+            AutoModelForCausalLM,
+            LlamaForCausalLM,
+            AutoModelForSeq2SeqLM,
+            LlamaTokenizer,
+            AutoTokenizer,
+        )
+    except ImportError as exc:
+        raise ValueError(
+            "Could not import the required python packages. "
+            "Please install them with `pip install transformers` "
+            "`pip install bitsandbytes` `pip install accelerate`."
+        ) from exc
+
+    if (
+        "llama-2" in model_params.model_name.lower()
+        and not transformers.__version__ >= "4.31.0"
+    ):
+        raise ValueError(
+            "Llama-2 quantization requires transformers.__version__>=4.31.0"
+        )
+
+    params = {"low_cpu_mem_usage": True}
+    params["device_map"] = "auto"
+
+    torch_dtype = kwargs.get("torch_dtype")
+
+    if model_params.load_4bit:
+        compute_dtype = None
+        if model_params.compute_dtype and model_params.compute_dtype in [
+            "bfloat16",
+            "float16",
+            "float32",
+        ]:
+            compute_dtype = getattr(torch, model_params.compute_dtype)
+
+        quantization_config_params = {
+            "load_in_4bit": True,
+            "bnb_4bit_compute_dtype": compute_dtype,
+            "bnb_4bit_quant_type": model_params.quant_type,
+            "bnb_4bit_use_double_quant": model_params.use_double_quant,
+        }
+        logger.warn(
+            "Using the following 4-bit params: " + str(quantization_config_params)
+        )
+        params["quantization_config"] = BitsAndBytesConfig(**quantization_config_params)
+    elif model_params.load_8bit and max_memory:
+        params["quantization_config"] = BitsAndBytesConfig(
+            load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
+        )
+    elif model_params.load_8bit:
+        params["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
+    params["torch_dtype"] = torch_dtype if torch_dtype else torch.float16
+    params["max_memory"] = max_memory
+
+    if "chatglm" in model_params.model_name.lower():
+        LoaderClass = AutoModel
+    else:
+        config = AutoConfig.from_pretrained(
+            model_params.model_path, trust_remote_code=model_params.trust_remote_code
+        )
+        if config.to_dict().get("is_encoder_decoder", False):
+            LoaderClass = AutoModelForSeq2SeqLM
+        else:
+            LoaderClass = AutoModelForCausalLM
+
+    if model_params.load_8bit and max_memory is not None:
+        config = AutoConfig.from_pretrained(
+            model_params.model_path, trust_remote_code=model_params.trust_remote_code
+        )
+        with init_empty_weights():
+            model = LoaderClass.from_config(
+                config, trust_remote_code=model_params.trust_remote_code
+            )
+
+        model.tie_weights()
+        params["device_map"] = infer_auto_device_map(
+            model,
+            dtype=torch.int8,
+            max_memory=params["max_memory"].copy(),
+            no_split_module_classes=model._no_split_modules,
+        )
+    try:
+        if model_params.trust_remote_code:
+            params["trust_remote_code"] = True
+        logger.info(f"params: {params}")
+        model = LoaderClass.from_pretrained(model_params.model_path, **params)
+    except Exception as e:
+        logger.error(
+            f"Load quantization model failed, error: {str(e)}, params: {params}"
+        )
+        raise e
+
+    # Loading the tokenizer
+    if type(model) is LlamaForCausalLM:
+        tokenizer = LlamaTokenizer.from_pretrained(
+            model_params.model_path, clean_up_tokenization_spaces=True
+        )
+        # Leaving this here until the LLaMA tokenizer gets figured out.
+        # For some people this fixes things, for others it causes an error.
+        try:
+            tokenizer.eos_token_id = 2
+            tokenizer.bos_token_id = 1
+            tokenizer.pad_token_id = 0
+        except Exception as e:
+            logger.warn(f"{str(e)}")
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_params.model_path,
+            trust_remote_code=model_params.trust_remote_code,
+            use_fast=llm_adapter.use_fast_tokenizer(),
+        )
+
+    return model, tokenizer
diff --git a/pilot/server/llmserver.py b/pilot/server/llmserver.py
index 9a34f7685..b02dc8525 100644
--- a/pilot/server/llmserver.py
+++ b/pilot/server/llmserver.py
@@ -37,7 +37,7 @@ class ModelWorker:
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
         print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
-        self.ml = ModelLoader(model_path=model_path)
+        self.ml = ModelLoader(model_path=model_path, model_name=self.model_name)
         self.model, self.tokenizer = self.ml.loader(
             num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
         )
diff --git a/requirements.txt b/requirements.txt
index 78b5b9f1e..008560eb3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,10 +1,8 @@
 torch==2.0.0
-accelerate==0.16.0
 aiohttp==3.8.4
 aiosignal==1.3.1
 async-timeout==4.0.2
 attrs==22.2.0
-bitsandbytes==0.39.0
 cchardet==2.1.7
 chardet==5.1.0
 contourpy==1.0.7
@@ -27,7 +25,7 @@ python-dateutil==2.8.2
 pyyaml==6.0
 tokenizers==0.13.2
 tqdm==4.64.1
-transformers==4.30.0
+transformers>=4.31.0
 transformers_stream_generator
 timm==0.6.13
 spacy==3.5.3
@@ -48,6 +46,9 @@ gradio-client==0.0.8
 wandb
 llama-index==0.5.27
 
+bitsandbytes
+accelerate>=0.20.3
+
 unstructured==0.6.3
 grpcio==1.47.5
 gpt4all==0.3.0
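
Note: the snippet below is illustrative only and not part of the patch. It is a minimal sketch of how the 4-bit settings wired up in `load_huggingface_quantization_model` (nf4 quant type, nested/double quantization, configurable compute dtype, `device_map="auto"` for multi-GPU sharding) map onto the transformers/bitsandbytes API; the model path is a placeholder, and it assumes transformers>=4.31.0, bitsandbytes, and accelerate are installed as pinned in requirements.txt.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    model_path = "/path/to/your/model"  # placeholder, not shipped by this patch

    # Mirrors the ModelParams defaults: quant_type="nf4", use_double_quant=True
    quant_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 / torch.float32
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",          # shard across all visible GPUs
        low_cpu_mem_usage=True,
        quantization_config=quant_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)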