feat: Support vicuna-v1.5 and WizardLM-v1.2

FangYin Cheng 2023-08-03 14:13:50 +08:00
parent 1388f33ddc
commit a4574aa614
11 changed files with 140 additions and 49 deletions

View File

@ -82,7 +82,9 @@ Currently, we have released multiple key features, which are listed below to dem
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support: supports multiple large language models, currently supporting
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)
- Vicuna (7b,13b)
- ChatGLM-6b (int4,int8)
- ChatGLM2-6b (int4,int8)

View File

@ -112,12 +112,15 @@ https://github.com/csunny/DB-GPT/assets/13723926/55f31781-1d49-4757-b96e-7ef6d3d
- Multi-model support
- Supports multiple large language models; currently supported models:
- Vicuna(7b,13b)
- ChatGLM-6b(int4,int8)
- guanaco(7b,13b,33b)
- Gorilla(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- baichuan(7b,13b)
- 🔥 Vicuna-v1.5(7b,13b)
- 🔥 llama-2(7b,13b,70b)
- WizardLM-v1.2(13b)
- Vicuna (7b,13b)
- ChatGLM-6b (int4,int8)
- ChatGLM2-6b (int4,int8)
- guanaco(7b,13b,33b)
- Gorilla(7b,13b)
- baichuan(7b,13b)
## Architecture
DB-GPT uses [FastChat](https://github.com/lm-sys/FastChat) to build the large-model runtime environment and provides vicuna as the base large language model. It also offers private knowledge-base question answering through LangChain, and supports a plugin mode with native support for Auto-GPT plugins by design. Our vision is to make building applications around databases and LLMs simpler and more convenient.

View File

@ -4,10 +4,15 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd)
if [[ " $* " == *" --help "* ]] || [[ " $* " == *" -h "* ]]; then
bash $WORK_DIR/base/build_image.sh "$@"
exit 0
fi
bash $WORK_DIR/base/build_image.sh "$@"
if [ 0 -ne $? ]; then
ehco "Error: build base image failed"
echo "Error: build base image failed"
exit 1
fi

View File

@ -48,6 +48,7 @@ Notice make sure you have install git-lfs
```
```bash
git clone https://huggingface.co/lmsys/vicuna-13b-v1.5
git clone https://huggingface.co/Tribbiani/vicuna-13b
git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese
@ -62,6 +63,8 @@ cp .env.template .env
You can configure basic parameters in the `.env` file, for example setting `LLM_MODEL` to the model to be used.
([Vicuna-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5), based on llama-2, has been released; we recommend setting `LLM_MODEL=vicuna-13b-v1.5` to try this model.)
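For example, a minimal `.env` change for the new model (only the line discussed here; everything else keeps the values from `.env.template`):
```bash
# .env
LLM_MODEL=vicuna-13b-v1.5
```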
### 3. Run
You can refer to this document to obtain the Vicuna weights: [Vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights) .
@ -107,6 +110,16 @@ db-gpt-allinone latest e1ffd20b85ac 45 minutes ago 14.5GB
db-gpt latest e36fb0cca5d9 3 hours ago 14GB
```
You can pass some parameters to docker/build_all_images.sh.
```bash
$ bash docker/build_all_images.sh \
--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04 \
--pip-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
--language zh
```
You can execute the command `bash docker/build_all_images.sh --help` to see more usage details.
#### 4.2. Run all in one docker container
**Run with local model**
@ -158,7 +171,7 @@ $ docker run --gpus "device=0" -d -p 3306:3306 \
- `-e LLM_MODEL=proxyllm` means we use a proxy LLM (OpenAI interface, FastChat interface, ...)
- `-v /data/models/text2vec-large-chinese:/app/models/text2vec-large-chinese` mounts the local text2vec model into the docker container (a combined example is sketched below).
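Putting the two options together, a proxy-model run might look like the sketch below; the image name `db-gpt-allinone` comes from the image listing above, while the proxy variables and the `5000` web port are assumptions to verify against `.env.template` and the full command earlier in this section:
```bash
# A sketch only: PROXY_API_KEY / PROXY_SERVER_URL and the 5000 web port are assumptions,
# check .env.template and the full local-model command above for the exact names.
$ docker run --gpus "device=0" -d -p 3306:3306 -p 5000:5000 \
    -e LLM_MODEL=proxyllm \
    -e PROXY_API_KEY="your-api-key" \
    -e PROXY_SERVER_URL="https://api.openai.com/v1/chat/completions" \
    -v /data/models/text2vec-large-chinese:/app/models/text2vec-large-chinese \
    --name db-gpt-allinone \
    db-gpt-allinone
```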
#### 4.2. Run with docker compose
#### 4.3. Run with docker compose
```bash
$ docker compose up -d
@ -197,6 +210,8 @@ CUDA_VISIBLE_DEVICES=0 python3 pilot/server/dbgpt_server.py
CUDA_VISIBLE_DEVICES=3,4,5,6 python3 pilot/server/dbgpt_server.py
````
You can modify the setting `MAX_GPU_MEMORY=xxGib` in the `.env` file to configure the maximum memory used by each GPU.
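For example, a minimal sketch of that setting (the `16Gib` value is only an illustration of the `xxGib` placeholder; size it to your hardware):
```bash
# .env -- cap the memory used on each visible GPU at roughly 16 GB
MAX_GPU_MEMORY=16Gib
```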
### 6. Not Enough Memory
DB-GPT supports 8-bit and 4-bit quantization.
@ -205,4 +220,24 @@ You can modify the setting `QUANTIZE_8bit=True` or `QUANTIZE_4bit=True` in `.env
Llama-2-70b with 8-bit quantization can run with 80 GB of VRAM, and 4-bit quantization can run with 48 GB of VRAM.
Note: you need to install the latest dependencies according to [requirements.txt](https://github.com/eosphoros-ai/DB-GPT/blob/main/requirements.txt).
Here is the approximate VRAM usage of the models we tested in some common scenarios.
| Model | Quantize | VRAM Size |
| --------- | --------- | --------- |
| vicuna-7b-v1.5 | 4-bit | 8 GB |
| vicuna-7b-v1.5 | 8-bit | 12 GB |
| vicuna-13b-v1.5 | 4-bit | 12 GB |
| vicuna-13b-v1.5 | 8-bit | 20 GB |
| llama-2-7b | 4-bit | 8 GB |
| llama-2-7b | 8-bit | 12 GB |
| llama-2-13b | 4-bit | 12 GB |
| llama-2-13b | 8-bit | 20 GB |
| llama-2-70b | 4-bit | 48 GB |
| llama-2-70b | 8-bit | 80 GB |
| baichuan-7b | 4-bit | 8 GB |
| baichuan-7b | 8-bit | 12 GB |
| baichuan-13b | 4-bit | 12 GB |
| baichuan-13b | 8-bit | 20 GB |
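To enable one of the quantized modes from the table, a minimal `.env` sketch (per the `Config` change later in this commit, 4-bit takes precedence when both flags are `True`):
```bash
# .env -- enable 4-bit quantization; keep QUANTIZE_8bit=False to avoid ambiguity
QUANTIZE_8bit=False
QUANTIZE_4bit=True
```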

View File

@ -155,8 +155,8 @@ class Config(metaclass=Singleton):
# QLoRA
self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
self.IS_LOAD_8BIT = bool(os.getenv("QUANTIZE_8bit", "True"))
self.IS_LOAD_4BIT = bool(os.getenv("QUANTIZE_4bit", "False"))
self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True"
self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
self.IS_LOAD_8BIT = False
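The replaced lines fix a truthiness bug: `bool()` of any non-empty string is `True`, so `QUANTIZE_8bit=False` in `.env` could never disable 8-bit loading. A minimal stand-alone illustration (not part of the diff):
```python
import os

os.environ["QUANTIZE_8bit"] = "False"

# Old behaviour: bool("False") is True, so the flag could not be turned off.
print(bool(os.getenv("QUANTIZE_8bit", "True")))      # True

# New behaviour: compare the raw string, which respects the .env value.
print(os.getenv("QUANTIZE_8bit", "True") == "True")  # False
```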

View File

@ -33,6 +33,9 @@ LLM_MODEL_CONFIG = {
"flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
"vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
"vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
# (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
"vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
"vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
"text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
"sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
"codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
@ -59,6 +62,8 @@ LLM_MODEL_CONFIG = {
"baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
# please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
"baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
# (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
"wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
}
# Load model config

View File

@ -291,6 +291,11 @@ class BaichuanAdapter(BaseLLMAdaper):
return model, tokenizer
class WizardLMAdapter(BaseLLMAdaper):
def match(self, model_path: str):
return "wizardlm" in model_path.lower()
register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
@ -299,6 +304,7 @@ register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
# TODO Default support vicuna; other models need to be tested and evaluated
# just for test_py, remove this later

View File

@ -299,6 +299,21 @@ register_conv_template(
)
)
# Vicuna v1.1 template
register_conv_template(
Conversation(
name="vicuna_v1.1",
system="A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
roles=("USER", "ASSISTANT"),
messages=(),
offset=0,
sep_style=SeparatorStyle.ADD_COLON_TWO,
sep=" ",
sep2="</s>",
)
)
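With `SeparatorStyle.ADD_COLON_TWO`, the system prompt and each turn are joined with `sep` and the assistant turn is terminated with `sep2`, so a single exchange renders roughly as follows (a sketch of the expected prompt string, not output captured from this code):
```
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: Hello! ASSISTANT: Hi there!</s>
```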
# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(

View File

@ -44,6 +44,8 @@ class ModelParams:
def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
# TODO: vicuna-v1.5 8-bit quantization is slow
# TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
model_name = model_params.model_name.lower()
supported_models = ["llama", "baichuan", "vicuna"]
return any(m in model_name for m in supported_models)
@ -89,7 +91,6 @@ class ModelLoader(metaclass=Singleton):
# TODO multi gpu support
def loader(
self,
num_gpus,
load_8bit=False,
load_4bit=False,
debug=False,
@ -100,14 +101,13 @@ class ModelLoader(metaclass=Singleton):
device=self.device,
model_path=self.model_path,
model_name=self.model_name,
num_gpus=num_gpus,
max_gpu_memory=max_gpu_memory,
cpu_offloading=cpu_offloading,
load_8bit=load_8bit,
load_4bit=load_4bit,
debug=debug,
)
logger.info(f"model_params:\n{model_params}")
llm_adapter = get_llm_model_adapter(model_params.model_path)
return huggingface_loader(llm_adapter, model_params)
@ -126,13 +126,14 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
}
if num_gpus != 1:
kwargs["device_map"] = "auto"
kwargs["max_memory"] = max_memory
elif model_params.max_gpu_memory:
logger.info(
f"There has max_gpu_memory from config: {model_params.max_gpu_memory}"
)
max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
kwargs["max_memory"] = max_memory
if model_params.max_gpu_memory:
logger.info(
f"There has max_gpu_memory from config: {model_params.max_gpu_memory}"
)
max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
kwargs["max_memory"] = max_memory
else:
kwargs["max_memory"] = max_memory
logger.debug(f"max_memory: {max_memory}")
elif device == "mps":
@ -282,6 +283,9 @@ def load_huggingface_quantization_model(
# Loading the tokenizer
if type(model) is LlamaForCausalLM:
logger.info(
f"Current model is type of: LlamaForCausalLM, load tokenizer by LlamaTokenizer"
)
tokenizer = LlamaTokenizer.from_pretrained(
model_params.model_path, clean_up_tokenization_spaces=True
)
@ -294,6 +298,9 @@ def load_huggingface_quantization_model(
except Exception as e:
logger.warn(f"{str(e)}")
else:
logger.info(
f"Current model type is not LlamaForCausalLM, load tokenizer by AutoTokenizer"
)
tokenizer = AutoTokenizer.from_pretrained(
model_params.model_path,
trust_remote_code=model_params.trust_remote_code,

View File

@ -15,9 +15,11 @@ class BaseChatAdpter:
def match(self, model_path: str):
return True
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
"""Return the generate stream handler func"""
pass
from pilot.model.inference import generate_stream
return generate_stream
def get_conv_template(self, model_path: str) -> Conversation:
return None
@ -105,10 +107,21 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter(BaseChatAdpter):
"""Model chat Adapter for vicuna"""
def match(self, model_path: str):
return "vicuna" in model_path
def _is_llama2_based(self, model_path: str):
# see https://huggingface.co/lmsys/vicuna-13b-v1.5
return "v1.5" in model_path.lower()
def get_generate_stream_func(self):
def match(self, model_path: str):
return "vicuna" in model_path.lower()
def get_conv_template(self, model_path: str) -> Conversation:
if self._is_llama2_based(model_path):
return get_conv_template("vicuna_v1.1")
return None
def get_generate_stream_func(self, model_path: str):
if self._is_llama2_based(model_path):
return super().get_generate_stream_func(model_path)
return generate_stream
@ -118,7 +131,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "chatglm" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream
return chatglm_generate_stream
@ -130,7 +143,7 @@ class CodeT5ChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "codet5" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
# TODO
pass
@ -141,7 +154,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "codegen" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
# TODO
pass
@ -152,7 +165,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "guanaco" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream
return guanaco_generate_stream
@ -164,7 +177,7 @@ class FalconChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "falcon" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.falcon_llm import falcon_generate_output
return falcon_generate_output
@ -174,7 +187,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "proxyllm" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream
return proxyllm_generate_stream
@ -184,7 +197,7 @@ class GorillaChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gorilla" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.gorilla_llm import generate_stream
return generate_stream
@ -194,7 +207,7 @@ class GPT4AllChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "gpt4all" in model_path
def get_generate_stream_func(self):
def get_generate_stream_func(self, model_path: str):
from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream
return gpt4all_generate_stream
@ -207,11 +220,6 @@ class Llama2ChatAdapter(BaseChatAdpter):
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("llama-2")
def get_generate_stream_func(self):
from pilot.model.inference import generate_stream
return generate_stream
class BaichuanChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
@ -222,10 +230,13 @@ class BaichuanChatAdapter(BaseChatAdpter):
return get_conv_template("baichuan-chat")
return get_conv_template("zero_shot")
def get_generate_stream_func(self):
from pilot.model.inference import generate_stream
return generate_stream
class WizardLMChatAdapter(BaseChatAdpter):
def match(self, model_path: str):
return "wizardlm" in model_path.lower()
def get_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("vicuna_v1.1")
register_llm_model_chat_adapter(VicunaChatAdapter)
@ -236,6 +247,7 @@ register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)
# Proxy model for testing and development; it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)

View File

@ -31,15 +31,16 @@ CFG = Config()
class ModelWorker:
def __init__(self, model_path, model_name, device, num_gpus=1):
def __init__(self, model_path, model_name, device):
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_name = model_name or model_path.split("/")[-1]
self.device = device
print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
self.ml = ModelLoader(model_path=model_path, model_name=self.model_name)
self.ml: ModelLoader = ModelLoader(
model_path=model_path, model_name=self.model_name
)
self.model, self.tokenizer = self.ml.loader(
num_gpus,
load_8bit=CFG.IS_LOAD_8BIT,
load_4bit=CFG.IS_LOAD_4BIT,
debug=ISDEBUG,
@ -60,7 +61,9 @@ class ModelWorker:
self.context_len = 2048
self.llm_chat_adapter = get_llm_chat_adapter(model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
model_path
)
def start_check(self):
print("LLM Model Loading Success")
@ -111,9 +114,7 @@ class ModelWorker:
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(
model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
)
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)
app = FastAPI()
# from pilot.openapi.knowledge.knowledge_controller import router