Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-17 15:58:25 +00:00)

commit 1388f33ddc

Support 8-bit quantization and 4-bit quantization (#399)

Close #389

**Some extra features:**
1. Configure the maximum memory used by each GPU
2. Unified model loading entry (for huggingface, llama.cpp, etc.)
3. Enhance docker image building
@@ -27,6 +27,8 @@ MODEL_SERVER=http://127.0.0.1:8000
 LIMIT_MODEL_CONCURRENCY=5
 MAX_POSITION_EMBEDDINGS=4096
 QUANTIZE_QLORA=True
+QUANTIZE_8bit=True
+# QUANTIZE_4bit=False
 ## SMART_LLM_MODEL - Smart language model (Default: vicuna-13b)
 ## FAST_LLM_MODEL - Fast language model (Default: chatglm-6b)
 # SMART_LLM_MODEL=vicuna-13b
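To make the interplay of the two new keys concrete, here is a small, assumed `.env` sketch; the key names come from the template above, and the comment about dropping 8-bit reflects the `Config` change further down in this commit:

```bash
# .env sketch (illustrative, not part of the commit)
QUANTIZE_8bit=True       # default: load the model with 8-bit quantization
# QUANTIZE_4bit=True     # enable 4-bit instead; when both are set, Config drops 8-bit
```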
@@ -125,11 +127,15 @@ PROXY_SERVER_URL=https://api.openai.com/v1/chat/completions
 BARD_PROXY_API_KEY={your-bard-token}
 
 #*******************************************************************#
-# ** SUMMARY_CONFIG
+#** SUMMARY_CONFIG **#
 #*******************************************************************#
 SUMMARY_CONFIG=FAST
 
 #*******************************************************************#
-# ** MUlti-GPU
+#** MUlti-GPU **#
 #*******************************************************************#
-NUM_GPUS = 1
+## See https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/
+## If CUDA_VISIBLE_DEVICES is not configured, all available gpus will be used
+# CUDA_VISIBLE_DEVICES=0
+## You can configure the maximum memory used by each GPU.
+# MAX_GPU_MEMORY=16Gib
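The GPU keys that replace the old `NUM_GPUS` setting can be combined as in the following sketch (values are only illustrative; per the template comment, omitting `CUDA_VISIBLE_DEVICES` uses all available GPUs):

```bash
# .env sketch with illustrative values
CUDA_VISIBLE_DEVICES=0,1   # restrict DB-GPT to GPUs 0 and 1
MAX_GPU_MEMORY=16Gib       # cap the memory used on each visible GPU
```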
@@ -1,25 +1,48 @@
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
+ARG BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
 
+FROM ${BASE_IMAGE}
+ARG BASE_IMAGE
+
 RUN apt-get update && apt-get install -y git python3 pip wget \
     && apt-get clean
 
-# download code from githu: https://github.com/csunny/DB-GPT
-# ENV DBGPT_VERSION="v0.3.3"
-# RUN wget https://github.com/csunny/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
+ARG BUILD_LOCAL_CODE="false"
+ARG LANGUAGE="en"
+ARG PIP_INDEX_URL="https://pypi.org/simple"
+ENV PIP_INDEX_URL=$PIP_INDEX_URL
 
-# clone latest code, and rename to /app
-RUN git clone https://github.com/csunny/DB-GPT.git /app
+# COPY only requirements.txt first to leverage Docker cache
+COPY ./requirements.txt /tmp/requirements.txt
 
 WORKDIR /app
 
-RUN pip3 install --upgrade pip \
-    && pip3 install --no-cache-dir -r requirements.txt \
-    && pip3 install seaborn mpld3 \
-    && wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
-    && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
-    && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
+RUN pip3 install --upgrade pip -i $PIP_INDEX_URL \
+    && (if [ "${BUILD_LOCAL_CODE}" = "false" ]; \
+        # if not build local code, clone latest code from git, and rename to /app, TODO: download by version, like: https://github.com/eosphoros-ai/DB-GPT/archive/refs/tags/$DBGPT_VERSION.zip
+        then git clone https://github.com/eosphoros-ai/DB-GPT.git /app \
+            && cp /app/requirements.txt /tmp/requirements.txt; \
+        fi;) \
+    && pip3 install -r /tmp/requirements.txt -i $PIP_INDEX_URL --no-cache-dir \
+    && rm /tmp/requirements.txt
+
+RUN (if [ "${LANGUAGE}" = "zh" ]; \
+        # language is zh, download zh_core_web_sm from github
+        then wget https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.5.0/zh_core_web_sm-3.5.0-py3-none-any.whl -O /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl \
+            && pip3 install /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl -i $PIP_INDEX_URL \
+            && rm /tmp/zh_core_web_sm-3.5.0-py3-none-any.whl; \
+        # not zh, download directly
+        else python3 -m spacy download zh_core_web_sm; \
+        fi;) \
     && rm -rf `pip3 cache dir`
 
-# RUN python3 -m spacy download zh_core_web_sm
+ARG BUILD_LOCAL_CODE="false"
+# COPY the rest of the app
+COPY . /tmp/app
+
+# TODO: Need to find a better way to determine whether to build docker image with local code.
+RUN (if [ "${BUILD_LOCAL_CODE}" = "true" ]; \
+    then mv /tmp/app / && rm -rf /app/logs && rm -rf /app/pilot/data && rm -rf /app/pilot/message; \
+    else rm -rf /tmp/app; \
+    fi;)
 
 EXPOSE 5000
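With the new ARGs the image becomes parameterizable. A hypothetical invocation from the repository root is sketched below; the `docker/base/Dockerfile` path and the Chinese pip mirror are assumptions taken from the build script further down:

```bash
# Sketch: build with local code, the Chinese spaCy model, and a zh pip mirror
docker build \
    --build-arg BUILD_LOCAL_CODE=true \
    --build-arg LANGUAGE=zh \
    --build-arg PIP_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple" \
    -f docker/base/Dockerfile \
    -t db-gpt .
```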
@@ -4,5 +4,72 @@ SCRIPT_LOCATION=$0
 cd "$(dirname "$SCRIPT_LOCATION")"
 WORK_DIR=$(pwd)
 
+BASE_IMAGE="nvidia/cuda:11.8.0-devel-ubuntu22.04"
 IMAGE_NAME="db-gpt"
-docker build -f Dockerfile -t $IMAGE_NAME $WORK_DIR/../../
+# zh: https://pypi.tuna.tsinghua.edu.cn/simple
+PIP_INDEX_URL="https://pypi.org/simple"
+# en or zh
+LANGUAGE="en"
+BUILD_LOCAL_CODE="false"
+
+usage () {
+    echo "USAGE: $0 [--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04] [--image-name db-gpt]"
+    echo "  [-b|--base-image base image name] Base image name"
+    echo "  [-n|--image-name image name] Current image name, default: db-gpt"
+    echo "  [-i|--pip-index-url pip index url] Pip index url, default: https://pypi.org/simple"
+    echo "  [--language en or zh] Your language, default: en"
+    echo "  [--build-local-code true or false] Whether to use the local project code to package the image, default: false"
+    echo "  [-h|--help] Usage message"
+}
+
+while [[ $# -gt 0 ]]; do
+    key="$1"
+    case $key in
+        -b|--base-image)
+            BASE_IMAGE="$2"
+            shift # past argument
+            shift # past value
+            ;;
+        -n|--image-name)
+            IMAGE_NAME="$2"
+            shift # past argument
+            shift # past value
+            ;;
+        -i|--pip-index-url)
+            PIP_INDEX_URL="$2"
+            shift
+            shift
+            ;;
+        --language)
+            LANGUAGE="$2"
+            shift
+            shift
+            ;;
+        --build-local-code)
+            BUILD_LOCAL_CODE="$2"
+            shift
+            shift
+            ;;
+        -h|--help)
+            help="true"
+            shift
+            ;;
+        *)
+            usage
+            exit 1
+            ;;
+    esac
+done
+
+if [[ $help ]]; then
+    usage
+    exit 0
+fi
+
+docker build \
+    --build-arg BASE_IMAGE=$BASE_IMAGE \
+    --build-arg PIP_INDEX_URL=$PIP_INDEX_URL \
+    --build-arg LANGUAGE=$LANGUAGE \
+    --build-arg BUILD_LOCAL_CODE=$BUILD_LOCAL_CODE \
+    -f Dockerfile \
+    -t $IMAGE_NAME $WORK_DIR/../../
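Equivalently, the rewritten script wires those ARGs to command-line flags. A usage sketch follows; the `docker/base/build_image.sh` location is an assumption, while the flags and the mirror URL come from the script itself:

```bash
# Sketch: build the base image via the script's new options
bash docker/base/build_image.sh \
    --image-name db-gpt \
    --pip-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
    --language zh \
    --build-local-code false
```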
@@ -4,6 +4,11 @@ SCRIPT_LOCATION=$0
 cd "$(dirname "$SCRIPT_LOCATION")"
 WORK_DIR=$(pwd)
 
-bash $WORK_DIR/base/build_image.sh
+bash $WORK_DIR/base/build_image.sh "$@"
 
+if [ 0 -ne $? ]; then
+    echo "Error: build base image failed"
+    exit 1
+fi
+
 bash $WORK_DIR/allinone/build_image.sh
@@ -80,26 +80,11 @@ Open http://localhost:5000 with your browser to see the product.
 If you want to access an external LLM service, you need to 1.set the variables LLM_MODEL=YOUR_MODEL_NAME MODEL_SERVER=YOUR_MODEL_SERVER(eg:http://localhost:5000) in the .env file.
 2.execute dbgpt_server.py in light mode
+
+If you want to learn about dbgpt-webui, read https://github.com/csunny/DB-GPT/tree/new-page-framework/datacenter
+
 ```bash
 $ python pilot/server/dbgpt_server.py --light
 ```
 
-#### 3.1 Steps for Starting ChatGLM-6B and ChatGLM2-6B with Multiple Cards
-
-Modify the .env.template or pilot/configurations/config.py file NUM_GPUS (quantity is the actual number of graphics cards required for startup)
-
-At the same time, it is necessary to specify the required gpu card ID before starting the command (note that the number of gpu cards specified is consistent with the number of NUM_GPUS), as shown below:
-
-````shell
-# Specify 1 gpu card
-NUM_GPUS = 1
-CUDA_VISIBLE_DEVICES=0 python3 pilot/server/dbgpt_server.py
-
-# Specify 4 gpus card
-NUM_GPUS = 4
-CUDA_VISIBLE_DEVICES=3,4,5,6 python3 pilot/server/dbgpt_server.py
-````
-
-If you want to learn about dbgpt-webui, read https://github.com/csunny/DB-GPT/tree/new-page-framework/datacenter
 
 ### 4. Docker (Experimental)
 
@@ -196,3 +181,28 @@ $ docker logs db-gpt-webserver-1 -f
 Open http://localhost:5000 with your browser to see the product.
 
 You can open docker-compose.yml in the project root directory to see more details.
+
+
+### 5. Multiple GPUs
+
+DB-GPT will use all available gpu by default. And you can modify the setting `CUDA_VISIBLE_DEVICES=0,1` in `.env` file to use the specific gpu IDs.
+
+Optionally, you can also specify the gpu ID to use before the starting command, as shown below:
+
+````shell
+# Specify 1 gpu
+CUDA_VISIBLE_DEVICES=0 python3 pilot/server/dbgpt_server.py
+
+# Specify 4 gpus
+CUDA_VISIBLE_DEVICES=3,4,5,6 python3 pilot/server/dbgpt_server.py
+````
+
+### 6. Not Enough Memory
+
+DB-GPT supported 8-bit quantization and 4-bit quantization.
+
+You can modify the setting `QUANTIZE_8bit=True` or `QUANTIZE_4bit=True` in `.env` file to use quantization(8-bit quantization is enabled by default).
+
+Llama-2-70b with 8-bit quantization can run with 80 GB of VRAM, and 4-bit quantization can run with 48 GB of VRAM.
+
+Note: you need to install the latest dependencies according to [requirements.txt](https://github.com/eosphoros-ai/DB-GPT/blob/main/requirements.txt).
@@ -29,7 +29,7 @@ class Config(metaclass=Singleton):
         self.skip_reprompt = False
         self.temperature = float(os.getenv("TEMPERATURE", 0.7))
 
-        self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
+        # self.NUM_GPUS = int(os.getenv("NUM_GPUS", 1))
 
         self.execute_local_commands = (
             os.getenv("EXECUTE_LOCAL_COMMANDS", "False") == "True"

@@ -145,7 +145,6 @@ class Config(metaclass=Singleton):
         self.MODEL_SERVER = os.getenv(
             "MODEL_SERVER", "http://127.0.0.1" + ":" + str(self.MODEL_PORT)
         )
-        self.ISLOAD_8BIT = os.getenv("ISLOAD_8BIT", "True") == "True"
 
         ### Vector Store Configuration
         self.VECTOR_STORE_TYPE = os.getenv("VECTOR_STORE_TYPE", "Chroma")

@@ -156,6 +155,10 @@ class Config(metaclass=Singleton):
 
         # QLoRA
         self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
+        self.IS_LOAD_8BIT = bool(os.getenv("QUANTIZE_8bit", "True"))
+        self.IS_LOAD_4BIT = bool(os.getenv("QUANTIZE_4bit", "False"))
+        if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
+            self.IS_LOAD_8BIT = False
 
         ### EMBEDDING Configuration
         self.EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text2vec")

@@ -164,6 +167,8 @@ class Config(metaclass=Singleton):
         ### SUMMARY_CONFIG Configuration
         self.SUMMARY_CONFIG = os.getenv("SUMMARY_CONFIG", "FAST")
 
+        self.MAX_GPU_MEMORY = os.getenv("MAX_GPU_MEMORY", None)
+
     def set_debug_mode(self, value: bool) -> None:
         """Set the debug mode value"""
         self.debug_mode = value
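Because `Config` reads these values with `os.getenv`, they can also be supplied per launch instead of through `.env`. A sketch (the `dbgpt_server.py` path is taken from the README; values are illustrative):

```bash
# Sketch: one-off launch with 8-bit quantization and a per-GPU memory cap
CUDA_VISIBLE_DEVICES=0,1 QUANTIZE_8bit=True MAX_GPU_MEMORY=16Gib \
    python3 pilot/server/dbgpt_server.py
```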
@@ -62,7 +62,6 @@ LLM_MODEL_CONFIG = {
 }
 
 # Load model config
-ISLOAD_8BIT = True
 ISDEBUG = False
 
 VECTOR_SEARCH_TOP_K = 10

@@ -247,7 +247,7 @@ def remove_color_codes(s: str) -> str:
     return ansi_escape.sub("", s)
 
 
-logger = Logger()
+logger: Logger = Logger()
 
 
 def print_assistant_thoughts(
@@ -28,6 +28,9 @@ class BaseLLMAdaper:
     """The Base class for multi model, in our project.
     We will support those model, which performance resemble ChatGPT"""
 
+    def use_fast_tokenizer(self) -> bool:
+        return False
+
     def match(self, model_path: str):
         return True
 

@@ -115,13 +118,7 @@ class ChatGLMAdapater(BaseLLMAdaper):
     def match(self, model_path: str):
         return "chatglm" in model_path
 
-    def loader(
-        self,
-        model_path: str,
-        from_pretrained_kwargs: dict,
-        device_map=None,
-        num_gpus=CFG.NUM_GPUS,
-    ):
+    def loader(self, model_path: str, from_pretrained_kwargs: dict):
         tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
         if DEVICE != "cuda":

@@ -130,6 +127,8 @@ class ChatGLMAdapater(BaseLLMAdaper):
             ).float()
             return model, tokenizer
         else:
+            device_map = None
+            num_gpus = torch.cuda.device_count()
             model = (
                 AutoModel.from_pretrained(
                     model_path, trust_remote_code=True, **from_pretrained_kwargs

@@ -138,9 +137,6 @@ class ChatGLMAdapater(BaseLLMAdaper):
            )
             from accelerate import dispatch_model
 
-            # model = AutoModel.from_pretrained(model_path, trust_remote_code=True,
-            # **from_pretrained_kwargs).half()
-            #
             if device_map is None:
                 device_map = auto_configure_device_map(num_gpus)
 
@@ -1,44 +1,70 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 
-import sys
-import warnings
-from typing import Optional
+from typing import Optional, Dict
+import dataclasses
 
 import torch
 
 from pilot.configs.model_config import DEVICE
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
 from pilot.model.compression import compress_module
 from pilot.model.llm.monkey_patch import replace_llama_attn_with_non_inplace_operations
 from pilot.singleton import Singleton
 from pilot.utils import get_gpu_memory
+from pilot.logs import logger
 
 
-def raise_warning_for_incompatible_cpu_offloading_configuration(
-    device: str, load_8bit: bool, cpu_offloading: bool
-):
-    if cpu_offloading:
-        if not load_8bit:
-            warnings.warn(
-                "The cpu-offloading feature can only be used while also using 8-bit-quantization.\n"
-                "Use '--load-8bit' to enable 8-bit-quantization\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-        if not "linux" in sys.platform:
-            warnings.warn(
-                "CPU-offloading is only supported on linux-systems due to the limited compatability with the bitsandbytes-package\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-        if device != "cuda":
-            warnings.warn(
-                "CPU-offloading is only enabled when using CUDA-devices\n"
-                "Continuing without cpu-offloading enabled\n"
-            )
-            return False
-    return cpu_offloading
+class ModelType:
+    """Type of model"""
+
+    HF = "huggingface"
+    LLAMA_CPP = "llama.cpp"
+    # TODO, support more model type
+
+
+@dataclasses.dataclass
+class ModelParams:
+    device: str
+    model_name: str
+    model_path: str
+    model_type: Optional[str] = ModelType.HF
+    num_gpus: Optional[int] = None
+    max_gpu_memory: Optional[str] = None
+    cpu_offloading: Optional[bool] = False
+    load_8bit: Optional[bool] = True
+    load_4bit: Optional[bool] = False
+    # quantization datatypes, `fp4` (four bit float) and `nf4` (normal four bit float)
+    quant_type: Optional[str] = "nf4"
+    # Nested quantization is activated through `use_double_quant`
+    use_double_quant: Optional[bool] = True
+    # "bfloat16", "float16", "float32"
+    compute_dtype: Optional[str] = None
+    debug: Optional[bool] = False
+    trust_remote_code: Optional[bool] = True
+
+
+def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
+    model_name = model_params.model_name.lower()
+    supported_models = ["llama", "baichuan", "vicuna"]
+    return any(m in model_name for m in supported_models)
+
+
+def _check_quantization(model_params: ModelParams):
+    model_name = model_params.model_name.lower()
+    has_quantization = any([model_params.load_8bit or model_params.load_4bit])
+    if has_quantization:
+        if model_params.device != "cuda":
+            logger.warn(
+                "8-bit quantization and 4-bit quantization just supported by cuda"
+            )
+            return False
+        elif "chatglm" in model_name:
+            if "int4" not in model_name:
+                logger.warn(
+                    "chatglm or chatglm2 not support quantization now, see: https://github.com/huggingface/transformers/issues/25228"
+                )
+            return False
+    return has_quantization
 
 
 class ModelLoader(metaclass=Singleton):

@@ -51,9 +77,10 @@ class ModelLoader(metaclass=Singleton):
 
     kwargs = {}
 
-    def __init__(self, model_path) -> None:
+    def __init__(self, model_path: str, model_name: str = None) -> None:
         self.device = DEVICE
         self.model_path = model_path
+        self.model_name = model_name
         self.kwargs = {
             "torch_dtype": torch.float16,
             "device_map": "auto",
|
|||||||
self,
|
self,
|
||||||
num_gpus,
|
num_gpus,
|
||||||
load_8bit=False,
|
load_8bit=False,
|
||||||
|
load_4bit=False,
|
||||||
debug=False,
|
debug=False,
|
||||||
cpu_offloading=False,
|
cpu_offloading=False,
|
||||||
max_gpu_memory: Optional[str] = None,
|
max_gpu_memory: Optional[str] = None,
|
||||||
):
|
):
|
||||||
if self.device == "cpu":
|
model_params = ModelParams(
|
||||||
kwargs = {"torch_dtype": torch.float32}
|
device=self.device,
|
||||||
|
model_path=self.model_path,
|
||||||
|
model_name=self.model_name,
|
||||||
|
num_gpus=num_gpus,
|
||||||
|
max_gpu_memory=max_gpu_memory,
|
||||||
|
cpu_offloading=cpu_offloading,
|
||||||
|
load_8bit=load_8bit,
|
||||||
|
load_4bit=load_4bit,
|
||||||
|
debug=debug,
|
||||||
|
)
|
||||||
|
|
||||||
elif self.device == "cuda":
|
llm_adapter = get_llm_model_adapter(model_params.model_path)
|
||||||
|
return huggingface_loader(llm_adapter, model_params)
|
||||||
|
|
||||||
|
|
||||||
|
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
|
||||||
|
device = model_params.device
|
||||||
|
max_memory = None
|
||||||
|
if device == "cpu":
|
||||||
|
kwargs = {"torch_dtype": torch.float32}
|
||||||
|
elif device == "cuda":
|
||||||
kwargs = {"torch_dtype": torch.float16}
|
kwargs = {"torch_dtype": torch.float16}
|
||||||
num_gpus = torch.cuda.device_count()
|
num_gpus = torch.cuda.device_count()
|
||||||
|
available_gpu_memory = get_gpu_memory(num_gpus)
|
||||||
|
max_memory = {
|
||||||
|
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" for i in range(num_gpus)
|
||||||
|
}
|
||||||
if num_gpus != 1:
|
if num_gpus != 1:
|
||||||
kwargs["device_map"] = "auto"
|
kwargs["device_map"] = "auto"
|
||||||
# if max_gpu_memory is None:
|
kwargs["max_memory"] = max_memory
|
||||||
# kwargs["device_map"] = "sequential"
|
elif model_params.max_gpu_memory:
|
||||||
|
logger.info(
|
||||||
|
f"There has max_gpu_memory from config: {model_params.max_gpu_memory}"
|
||||||
|
)
|
||||||
|
max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
|
||||||
|
kwargs["max_memory"] = max_memory
|
||||||
|
logger.debug(f"max_memory: {max_memory}")
|
||||||
|
|
||||||
available_gpu_memory = get_gpu_memory(num_gpus)
|
elif device == "mps":
|
||||||
kwargs["max_memory"] = {
|
kwargs = {"torch_dtype": torch.float16}
|
||||||
i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
|
|
||||||
for i in range(num_gpus)
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
|
|
||||||
|
|
||||||
elif self.device == "mps":
|
|
||||||
kwargs = kwargs = {"torch_dtype": torch.float16}
|
|
||||||
replace_llama_attn_with_non_inplace_operations()
|
replace_llama_attn_with_non_inplace_operations()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid device: {self.device}")
|
raise ValueError(f"Invalid device: {device}")
|
||||||
|
|
||||||
# TODO when cpu loading, need use quantization config
|
can_quantization = _check_quantization(model_params)
|
||||||
|
|
||||||
llm_adapter = get_llm_model_adapter(self.model_path)
|
if can_quantization and (num_gpus > 1 or model_params.load_4bit):
|
||||||
model, tokenizer = llm_adapter.loader(self.model_path, kwargs)
|
if _check_multi_gpu_or_4bit_quantization(model_params):
|
||||||
|
return load_huggingface_quantization_model(
|
||||||
if load_8bit and tokenizer:
|
llm_adapter, model_params, kwargs, max_memory
|
||||||
if num_gpus != 1:
|
|
||||||
warnings.warn(
|
|
||||||
"8-bit quantization is not supported for multi-gpu inference"
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
compress_module(model, self.device)
|
logger.warn(
|
||||||
|
f"Current model {model_params.model_name} not supported quantization"
|
||||||
|
)
|
||||||
|
# default loader
|
||||||
|
model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)
|
||||||
|
|
||||||
|
if model_params.load_8bit and num_gpus == 1 and tokenizer:
|
||||||
|
# TODO merge current code into `load_huggingface_quantization_model`
|
||||||
|
compress_module(model, model_params.device)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(self.device == "cuda" and num_gpus == 1 and not cpu_offloading)
|
(device == "cuda" and num_gpus == 1 and not model_params.cpu_offloading)
|
||||||
or self.device == "mps"
|
or device == "mps"
|
||||||
and tokenizer
|
and tokenizer
|
||||||
):
|
):
|
||||||
# 4-bit not support this
|
|
||||||
try:
|
try:
|
||||||
model.to(self.device)
|
model.to(device)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
pass
|
pass
|
||||||
|
if model_params.debug:
|
||||||
if debug:
|
|
||||||
print(model)
|
print(model)
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_huggingface_quantization_model(
|
||||||
|
llm_adapter: BaseLLMAdaper,
|
||||||
|
model_params: ModelParams,
|
||||||
|
kwargs: Dict,
|
||||||
|
max_memory: Dict[int, str],
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from accelerate.utils import infer_auto_device_map
|
||||||
|
import transformers
|
||||||
|
from transformers import (
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
LlamaForCausalLM,
|
||||||
|
AutoModelForSeq2SeqLM,
|
||||||
|
LlamaTokenizer,
|
||||||
|
AutoTokenizer,
|
||||||
|
)
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
"Could not import depend python package "
|
||||||
|
"Please install it with `pip install transformers` "
|
||||||
|
"`pip install bitsandbytes``pip install accelerate`."
|
||||||
|
) from exc
|
||||||
|
if (
|
||||||
|
"llama-2" in model_params.model_name.lower()
|
||||||
|
and not transformers.__version__ >= "4.31.0"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Llama-2 quantization require transformers.__version__>=4.31.0"
|
||||||
|
)
|
||||||
|
params = {"low_cpu_mem_usage": True}
|
||||||
|
params["low_cpu_mem_usage"] = True
|
||||||
|
params["device_map"] = "auto"
|
||||||
|
|
||||||
|
torch_dtype = kwargs.get("torch_dtype")
|
||||||
|
|
||||||
|
if model_params.load_4bit:
|
||||||
|
compute_dtype = None
|
||||||
|
if model_params.compute_dtype and model_params.compute_dtype in [
|
||||||
|
"bfloat16",
|
||||||
|
"float16",
|
||||||
|
"float32",
|
||||||
|
]:
|
||||||
|
compute_dtype = eval("torch.{}".format(model_params.compute_dtype))
|
||||||
|
|
||||||
|
quantization_config_params = {
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"bnb_4bit_compute_dtype": compute_dtype,
|
||||||
|
"bnb_4bit_quant_type": model_params.quant_type,
|
||||||
|
"bnb_4bit_use_double_quant": model_params.use_double_quant,
|
||||||
|
}
|
||||||
|
logger.warn(
|
||||||
|
"Using the following 4-bit params: " + str(quantization_config_params)
|
||||||
|
)
|
||||||
|
params["quantization_config"] = BitsAndBytesConfig(**quantization_config_params)
|
||||||
|
elif model_params.load_8bit and max_memory:
|
||||||
|
params["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True
|
||||||
|
)
|
||||||
|
elif model_params.load_in_8bit:
|
||||||
|
params["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||||
|
params["torch_dtype"] = torch_dtype if torch_dtype else torch.float16
|
||||||
|
params["max_memory"] = max_memory
|
||||||
|
|
||||||
|
if "chatglm" in model_params.model_name.lower():
|
||||||
|
LoaderClass = AutoModel
|
||||||
|
else:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_params.model_path, trust_remote_code=model_params.trust_remote_code
|
||||||
|
)
|
||||||
|
if config.to_dict().get("is_encoder_decoder", False):
|
||||||
|
LoaderClass = AutoModelForSeq2SeqLM
|
||||||
|
else:
|
||||||
|
LoaderClass = AutoModelForCausalLM
|
||||||
|
|
||||||
|
if model_params.load_8bit and max_memory is not None:
|
||||||
|
config = AutoConfig.from_pretrained(
|
||||||
|
model_params.model_path, trust_remote_code=model_params.trust_remote_code
|
||||||
|
)
|
||||||
|
with init_empty_weights():
|
||||||
|
model = LoaderClass.from_config(
|
||||||
|
config, trust_remote_code=model_params.trust_remote_code
|
||||||
|
)
|
||||||
|
|
||||||
|
model.tie_weights()
|
||||||
|
params["device_map"] = infer_auto_device_map(
|
||||||
|
model,
|
||||||
|
dtype=torch.int8,
|
||||||
|
max_memory=params["max_memory"].copy(),
|
||||||
|
no_split_module_classes=model._no_split_modules,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
if model_params.trust_remote_code:
|
||||||
|
params["trust_remote_code"] = True
|
||||||
|
logger.info(f"params: {params}")
|
||||||
|
model = LoaderClass.from_pretrained(model_params.model_path, **params)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"Load quantization model failed, error: {str(e)}, params: {params}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# Loading the tokenizer
|
||||||
|
if type(model) is LlamaForCausalLM:
|
||||||
|
tokenizer = LlamaTokenizer.from_pretrained(
|
||||||
|
model_params.model_path, clean_up_tokenization_spaces=True
|
||||||
|
)
|
||||||
|
# Leaving this here until the LLaMA tokenizer gets figured out.
|
||||||
|
# For some people this fixes things, for others it causes an error.
|
||||||
|
try:
|
||||||
|
tokenizer.eos_token_id = 2
|
||||||
|
tokenizer.bos_token_id = 1
|
||||||
|
tokenizer.pad_token_id = 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.warn(f"{str(e)}")
|
||||||
|
else:
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
model_params.model_path,
|
||||||
|
trust_remote_code=model_params.trust_remote_code,
|
||||||
|
use_fast=llm_adapter.use_fast_tokenizer(),
|
||||||
|
)
|
||||||
|
|
||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
@@ -37,9 +37,13 @@ class ModelWorker:
         self.model_name = model_name or model_path.split("/")[-1]
         self.device = device
         print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
-        self.ml = ModelLoader(model_path=model_path)
+        self.ml = ModelLoader(model_path=model_path, model_name=self.model_name)
         self.model, self.tokenizer = self.ml.loader(
-            num_gpus, load_8bit=ISLOAD_8BIT, debug=ISDEBUG
+            num_gpus,
+            load_8bit=CFG.IS_LOAD_8BIT,
+            load_4bit=CFG.IS_LOAD_4BIT,
+            debug=ISDEBUG,
+            max_gpu_memory=CFG.MAX_GPU_MEMORY,
         )
 
         if not isinstance(self.model, str):
@@ -1,10 +1,8 @@
 torch==2.0.0
-accelerate==0.16.0
 aiohttp==3.8.4
 aiosignal==1.3.1
 async-timeout==4.0.2
 attrs==22.2.0
-bitsandbytes==0.39.0
 cchardet==2.1.7
 chardet==5.1.0
 contourpy==1.0.7

@@ -27,7 +25,7 @@ python-dateutil==2.8.2
 pyyaml==6.0
 tokenizers==0.13.2
 tqdm==4.64.1
-transformers==4.30.0
+transformers>=4.31.0
 transformers_stream_generator
 timm==0.6.13
 spacy==3.5.3

@@ -48,6 +46,9 @@ gradio-client==0.0.8
 wandb
 llama-index==0.5.27
 
+bitsandbytes
+accelerate>=0.20.3
+
 unstructured==0.6.3
 grpcio==1.47.5
 gpt4all==0.3.0
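In line with the README note about installing the latest dependencies, the packages this commit moves or re-pins can be upgraded directly; a sketch whose versions mirror the new requirements.txt entries:

```bash
# Sketch: pull in the updated quantization dependencies
pip install -U "transformers>=4.31.0" "accelerate>=0.20.3" bitsandbytes
```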