feat: Support vicuna-v1.5 and WizardLM-v1.2
parent 1388f33ddc · commit a4574aa614
@@ -82,7 +82,9 @@ Currently, we have released multiple key features, which are listed below to dem
- Support for unstructured data such as PDF, TXT, Markdown, CSV, DOC, PPT, and WebURL
- Multi LLMs Support, Supports multiple large language models, currently supporting
  - 🔥 Vicuna-v1.5(7b,13b)
  - 🔥 llama-2(7b,13b,70b)
  - WizardLM-v1.2(13b)
  - Vicuna (7b,13b)
  - ChatGLM-6b (int4,int8)
  - ChatGLM2-6b (int4,int8)
README.zh.md (15 lines changed)
@@ -112,12 +112,15 @@ https://github.com/csunny/DB-GPT/assets/13723926/55f31781-1d49-4757-b96e-7ef6d3d
- Multi-model support
  - Supports multiple large language models; the following models are currently supported:
  - Vicuna(7b,13b)
  - ChatGLM-6b(int4,int8)
  - guanaco(7b,13b,33b)
  - Gorilla(7b,13b)
  - 🔥 llama-2(7b,13b,70b)
  - baichuan(7b,13b)
  - 🔥 Vicuna-v1.5(7b,13b)
  - 🔥 llama-2(7b,13b,70b)
  - WizardLM-v1.2(13b)
  - Vicuna (7b,13b)
  - ChatGLM-6b (int4,int8)
  - ChatGLM2-6b (int4,int8)
  - guanaco(7b,13b,33b)
  - Gorilla(7b,13b)
  - baichuan(7b,13b)

## Architecture

DB-GPT builds its LLM runtime on [FastChat](https://github.com/lm-sys/FastChat) and ships vicuna as the base large language model. In addition, it provides private-domain knowledge-base question answering via LangChain. It also supports a plugin mode, with native support for Auto-GPT plugins by design. Our vision is to make building applications around databases and LLMs simpler and more convenient.
@@ -4,10 +4,15 @@ SCRIPT_LOCATION=$0
cd "$(dirname "$SCRIPT_LOCATION")"
WORK_DIR=$(pwd)

if [[ " $* " == *" --help "* ]] || [[ " $* " == *" -h "* ]]; then
    bash $WORK_DIR/base/build_image.sh "$@"
    exit 0
fi

bash $WORK_DIR/base/build_image.sh "$@"

if [ 0 -ne $? ]; then
    ehco "Error: build base image failed"
    echo "Error: build base image failed"
    exit 1
fi
@@ -48,6 +48,7 @@ Notice make sure you have install git-lfs
```

```bash
git clone https://huggingface.co/lmsys/vicuna-13b-v1.5
git clone https://huggingface.co/Tribbiani/vicuna-13b
git clone https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2
git clone https://huggingface.co/GanymedeNil/text2vec-large-chinese

@@ -62,6 +63,8 @@ cp .env.template .env

You can configure basic parameters in the .env file, for example setting LLM_MODEL to the model to be used.

([Vicuna-v1.5](https://huggingface.co/lmsys/vicuna-13b-v1.5) based on llama-2 has been released, we recommend you set `LLM_MODEL=vicuna-13b-v1.5` to try this model)
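For clarity, the name set in `LLM_MODEL` is resolved to a local directory by a plain dictionary lookup. A minimal sketch follows; `MODEL_PATH = "./models"` is an assumption for the example, and the entries mirror the `LLM_MODEL_CONFIG` additions in `pilot/configs/model_config.py` further down in this diff:

```python
# Minimal sketch, not part of this commit: how an LLM_MODEL name from .env maps
# to a local model directory. MODEL_PATH is assumed; the real paths come from
# pilot/configs/model_config.py.
import os

MODEL_PATH = "./models"
LLM_MODEL_CONFIG = {
    "vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
    "vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
    "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
}

llm_model = os.getenv("LLM_MODEL", "vicuna-13b-v1.5")
print(f"LLM_MODEL={llm_model} -> {LLM_MODEL_CONFIG[llm_model]}")
```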
### 3. Run
You can refer to this document to obtain the Vicuna weights: [Vicuna](https://github.com/lm-sys/FastChat/blob/main/README.md#model-weights).
@@ -107,6 +110,16 @@ db-gpt-allinone latest e1ffd20b85ac 45 minutes ago 14.5GB
db-gpt latest e36fb0cca5d9 3 hours ago 14GB
```

You can pass some parameters to docker/build_all_images.sh.
```bash
$ bash docker/build_all_images.sh \
--base-image nvidia/cuda:11.8.0-devel-ubuntu22.04 \
--pip-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
--language zh
```

You can execute the command `bash docker/build_all_images.sh --help` to see more usage.

#### 4.2. Run all in one docker container

**Run with local model**
@@ -158,7 +171,7 @@ $ docker run --gpus "device=0" -d -p 3306:3306 \
- `-e LLM_MODEL=proxyllm`, means we use proxy llm(openai interface, fastchat interface...)
- `-v /data/models/text2vec-large-chinese:/app/models/text2vec-large-chinese`, means we mount the local text2vec model to the docker container.

#### 4.2. Run with docker compose
#### 4.3. Run with docker compose

```bash
$ docker compose up -d
@@ -197,6 +210,8 @@ CUDA_VISIBLE_DEVICES=0 python3 pilot/server/dbgpt_server.py
CUDA_VISIBLE_DEVICES=3,4,5,6 python3 pilot/server/dbgpt_server.py
```

You can modify the setting `MAX_GPU_MEMORY=xxGib` in `.env` file to configure the maximum memory used by each GPU.

### 6. Not Enough Memory

DB-GPT supports 8-bit quantization and 4-bit quantization.
@@ -205,4 +220,24 @@ You can modify the setting `QUANTIZE_8bit=True` or `QUANTIZE_4bit=True` in `.env
Llama-2-70b with 8-bit quantization can run with 80 GB of VRAM, and 4-bit quantization can run with 48 GB of VRAM.

Note: you need to install the latest dependencies according to [requirements.txt](https://github.com/eosphoros-ai/DB-GPT/blob/main/requirements.txt).

Here is the approximate VRAM usage of the models we tested in some common scenarios.

| Model           | Quantize | VRAM Size |
| --------------- | -------- | --------- |
| vicuna-7b-v1.5  | 4-bit    | 8 GB      |
| vicuna-7b-v1.5  | 8-bit    | 12 GB     |
| vicuna-13b-v1.5 | 4-bit    | 12 GB     |
| vicuna-13b-v1.5 | 8-bit    | 20 GB     |
| llama-2-7b      | 4-bit    | 8 GB      |
| llama-2-7b      | 8-bit    | 12 GB     |
| llama-2-13b     | 4-bit    | 12 GB     |
| llama-2-13b     | 8-bit    | 20 GB     |
| llama-2-70b     | 4-bit    | 48 GB     |
| llama-2-70b     | 8-bit    | 80 GB     |
| baichuan-7b     | 4-bit    | 8 GB      |
| baichuan-7b     | 8-bit    | 12 GB     |
| baichuan-13b    | 4-bit    | 12 GB     |
| baichuan-13b    | 8-bit    | 20 GB     |
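To make the quantization settings above more concrete, here is a rough sketch of how 4-bit loading typically looks with transformers and bitsandbytes. DB-GPT's own loader (see `pilot/model/loader.py` below) drives this from `QUANTIZE_4bit`/`QUANTIZE_8bit`, and its exact arguments may differ; the model path here is an assumed local clone location.

```python
# Illustrative sketch only, not DB-GPT's loader.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_path = "./models/vicuna-13b-v1.5"  # assumption: local clone location
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # ~12 GB VRAM for a 13b model (see table above)
    bnb_4bit_compute_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quant_config,
    device_map="auto",                     # place layers on the available GPUs
)
```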
@@ -155,8 +155,8 @@ class Config(metaclass=Singleton):

        # QLoRA
        self.QLoRA = os.getenv("QUANTIZE_QLORA", "True")
        self.IS_LOAD_8BIT = bool(os.getenv("QUANTIZE_8bit", "True"))
        self.IS_LOAD_4BIT = bool(os.getenv("QUANTIZE_4bit", "False"))
        self.IS_LOAD_8BIT = os.getenv("QUANTIZE_8bit", "True") == "True"
        self.IS_LOAD_4BIT = os.getenv("QUANTIZE_4bit", "False") == "True"
        if self.IS_LOAD_8BIT and self.IS_LOAD_4BIT:
            self.IS_LOAD_8BIT = False
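The two replaced lines fix a subtle bug: `bool()` of any non-empty string is `True`, so setting `QUANTIZE_8bit=False` in `.env` still enabled 8-bit loading. A quick demonstration of the difference:

```python
import os

os.environ["QUANTIZE_8bit"] = "False"
print(bool(os.getenv("QUANTIZE_8bit", "True")))      # True  -> old code ignores the setting
print(os.getenv("QUANTIZE_8bit", "True") == "True")  # False -> new code respects it
```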
@@ -33,6 +33,9 @@ LLM_MODEL_CONFIG = {
    "flan-t5-base": os.path.join(MODEL_PATH, "flan-t5-base"),
    "vicuna-13b": os.path.join(MODEL_PATH, "vicuna-13b"),
    "vicuna-7b": os.path.join(MODEL_PATH, "vicuna-7b"),
    # (Llama2 based) see https://huggingface.co/lmsys/vicuna-13b-v1.5
    "vicuna-13b-v1.5": os.path.join(MODEL_PATH, "vicuna-13b-v1.5"),
    "vicuna-7b-v1.5": os.path.join(MODEL_PATH, "vicuna-7b-v1.5"),
    "text2vec": os.path.join(MODEL_PATH, "text2vec-large-chinese"),
    "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
    "codegen2-1b": os.path.join(MODEL_PATH, "codegen2-1B"),
@@ -59,6 +62,8 @@ LLM_MODEL_CONFIG = {
    "baichuan-13b": os.path.join(MODEL_PATH, "Baichuan-13B-Chat"),
    # please rename "fireballoon/baichuan-vicuna-chinese-7b" to "baichuan-7b"
    "baichuan-7b": os.path.join(MODEL_PATH, "baichuan-7b"),
    # (Llama2 based) We only support WizardLM-13B-V1.2 for now, which is trained from Llama-2 13b, see https://huggingface.co/WizardLM/WizardLM-13B-V1.2
    "wizardlm-13b": os.path.join(MODEL_PATH, "WizardLM-13B-V1.2"),
}

# Load model config
@@ -291,6 +291,11 @@ class BaichuanAdapter(BaseLLMAdaper):
        return model, tokenizer


class WizardLMAdapter(BaseLLMAdaper):
    def match(self, model_path: str):
        return "wizardlm" in model_path.lower()


register_llm_model_adapters(VicunaLLMAdapater)
register_llm_model_adapters(ChatGLMAdapater)
register_llm_model_adapters(GuanacoAdapter)
@@ -299,6 +304,7 @@ register_llm_model_adapters(GorillaAdapter)
register_llm_model_adapters(GPT4AllAdapter)
register_llm_model_adapters(Llama2Adapter)
register_llm_model_adapters(BaichuanAdapter)
register_llm_model_adapters(WizardLMAdapter)
# TODO Default support vicuna, other model need to tests and Evaluate

# just for test_py, remove this later
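The registrations above rely on `get_llm_model_adapter` returning the first registered adapter whose `match()` accepts the model path. That function is not shown in this diff, so the selection loop below is an assumption about its behavior; this is only a toy version of the pattern:

```python
# Toy sketch of the registry/match pattern; not the project's actual implementation.
class BaseLLMAdaper:
    def match(self, model_path: str) -> bool:
        return True  # default adapter matches everything

class WizardLMAdapter(BaseLLMAdaper):
    def match(self, model_path: str) -> bool:
        return "wizardlm" in model_path.lower()

_adapters = []

def register_llm_model_adapters(cls) -> None:
    _adapters.append(cls())

def get_llm_model_adapter(model_path: str) -> BaseLLMAdaper:
    for adapter in _adapters:           # registration order matters
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No adapter for {model_path}")

register_llm_model_adapters(WizardLMAdapter)
register_llm_model_adapters(BaseLLMAdaper)  # catch-all goes last
print(type(get_llm_model_adapter("./models/WizardLM-13B-V1.2")).__name__)  # WizardLMAdapter
```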
@@ -299,6 +299,21 @@ register_conv_template(
    )
)

# Vicuna v1.1 template
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system="A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# llama2 template
# reference: https://github.com/facebookresearch/llama/blob/cfc3fc8c1968d390eb830e65c63865e980873a06/llama/generation.py#L212
register_conv_template(
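For reference, roughly what the `vicuna_v1.1` template renders: with `ADD_COLON_TWO`, each turn is terminated alternately by `sep` (a space) and `sep2` (`</s>`). This is a hand-rolled approximation, not the project's prompt-rendering code:

```python
# Approximation of ADD_COLON_TWO rendering for the vicuna_v1.1 template above.
system = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)
seps = (" ", "</s>")

def render(messages):
    # "ROLE: text" followed by the alternating separator; an empty message
    # leaves "ROLE:" open for the model to complete.
    ret = system + seps[0]
    for i, (role, message) in enumerate(messages):
        ret += f"{role}: {message}{seps[i % 2]}" if message else f"{role}:"
    return ret

print(render([("USER", "What is DB-GPT?"), ("ASSISTANT", None)]))
# ... USER: What is DB-GPT? ASSISTANT:
```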
@@ -44,6 +44,8 @@ class ModelParams:


def _check_multi_gpu_or_4bit_quantization(model_params: ModelParams):
    # TODO: vicuna-v1.5 8-bit quantization info is slow
    # TODO: support wizardlm quantization, see: https://huggingface.co/WizardLM/WizardLM-13B-V1.2/discussions/5
    model_name = model_params.model_name.lower()
    supported_models = ["llama", "baichuan", "vicuna"]
    return any(m in model_name for m in supported_models)
@@ -89,7 +91,6 @@ class ModelLoader(metaclass=Singleton):
    # TODO multi gpu support
    def loader(
        self,
        num_gpus,
        load_8bit=False,
        load_4bit=False,
        debug=False,
@@ -100,14 +101,13 @@ class ModelLoader(metaclass=Singleton):
            device=self.device,
            model_path=self.model_path,
            model_name=self.model_name,
            num_gpus=num_gpus,
            max_gpu_memory=max_gpu_memory,
            cpu_offloading=cpu_offloading,
            load_8bit=load_8bit,
            load_4bit=load_4bit,
            debug=debug,
        )

        logger.info(f"model_params:\n{model_params}")
        llm_adapter = get_llm_model_adapter(model_params.model_path)
        return huggingface_loader(llm_adapter, model_params)
@@ -126,13 +126,14 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParams):
        }
        if num_gpus != 1:
            kwargs["device_map"] = "auto"
            kwargs["max_memory"] = max_memory
        elif model_params.max_gpu_memory:
            logger.info(
                f"There has max_gpu_memory from config: {model_params.max_gpu_memory}"
            )
            max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
            kwargs["max_memory"] = max_memory
            if model_params.max_gpu_memory:
                logger.info(
                    f"There has max_gpu_memory from config: {model_params.max_gpu_memory}"
                )
                max_memory = {i: model_params.max_gpu_memory for i in range(num_gpus)}
                kwargs["max_memory"] = max_memory
            else:
                kwargs["max_memory"] = max_memory
        logger.debug(f"max_memory: {max_memory}")

    elif device == "mps":
@@ -282,6 +283,9 @@ def load_huggingface_quantization_model(

    # Loading the tokenizer
    if type(model) is LlamaForCausalLM:
        logger.info(
            f"Current model is type of: LlamaForCausalLM, load tokenizer by LlamaTokenizer"
        )
        tokenizer = LlamaTokenizer.from_pretrained(
            model_params.model_path, clean_up_tokenization_spaces=True
        )
@@ -294,6 +298,9 @@ def load_huggingface_quantization_model(
        except Exception as e:
            logger.warn(f"{str(e)}")
    else:
        logger.info(
            f"Current model type is not LlamaForCausalLM, load tokenizer by AutoTokenizer"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_params.model_path,
            trust_remote_code=model_params.trust_remote_code,
@@ -15,9 +15,11 @@ class BaseChatAdpter:
    def match(self, model_path: str):
        return True

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        """Return the generate stream handler func"""
        pass
        from pilot.model.inference import generate_stream

        return generate_stream

    def get_conv_template(self, model_path: str) -> Conversation:
        return None
@@ -105,10 +107,21 @@ def get_llm_chat_adapter(model_path: str) -> BaseChatAdpter:
class VicunaChatAdapter(BaseChatAdpter):
    """Model chat Adapter for vicuna"""

    def match(self, model_path: str):
        return "vicuna" in model_path
    def _is_llama2_based(self, model_path: str):
        # see https://huggingface.co/lmsys/vicuna-13b-v1.5
        return "v1.5" in model_path.lower()

    def get_generate_stream_func(self):
    def match(self, model_path: str):
        return "vicuna" in model_path.lower()

    def get_conv_template(self, model_path: str) -> Conversation:
        if self._is_llama2_based(model_path):
            return get_conv_template("vicuna_v1.1")
        return None

    def get_generate_stream_func(self, model_path: str):
        if self._is_llama2_based(model_path):
            return super().get_generate_stream_func(model_path)
        return generate_stream


@@ -118,7 +131,7 @@ class ChatGLMChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "chatglm" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.chatglm_llm import chatglm_generate_stream

        return chatglm_generate_stream
@@ -130,7 +143,7 @@ class CodeT5ChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "codet5" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        # TODO
        pass

@@ -141,7 +154,7 @@ class CodeGenChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "codegen" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        # TODO
        pass

@@ -152,7 +165,7 @@ class GuanacoChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "guanaco" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.guanaco_llm import guanaco_generate_stream

        return guanaco_generate_stream
@@ -164,7 +177,7 @@ class FalconChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "falcon" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.falcon_llm import falcon_generate_output

        return falcon_generate_output
@@ -174,7 +187,7 @@ class ProxyllmChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "proxyllm" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.proxy_llm import proxyllm_generate_stream

        return proxyllm_generate_stream
@@ -184,7 +197,7 @@ class GorillaChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "gorilla" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.gorilla_llm import generate_stream

        return generate_stream
@@ -194,7 +207,7 @@ class GPT4AllChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "gpt4all" in model_path

    def get_generate_stream_func(self):
    def get_generate_stream_func(self, model_path: str):
        from pilot.model.llm_out.gpt4all_llm import gpt4all_generate_stream

        return gpt4all_generate_stream
@@ -207,11 +220,6 @@ class Llama2ChatAdapter(BaseChatAdpter):
    def get_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("llama-2")

    def get_generate_stream_func(self):
        from pilot.model.inference import generate_stream

        return generate_stream


class BaichuanChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
@@ -222,10 +230,13 @@ class BaichuanChatAdapter(BaseChatAdpter):
        return get_conv_template("baichuan-chat")
        return get_conv_template("zero_shot")

    def get_generate_stream_func(self):
        from pilot.model.inference import generate_stream

        return generate_stream
class WizardLMChatAdapter(BaseChatAdpter):
    def match(self, model_path: str):
        return "wizardlm" in model_path.lower()

    def get_conv_template(self, model_path: str) -> Conversation:
        return get_conv_template("vicuna_v1.1")


register_llm_model_chat_adapter(VicunaChatAdapter)
@@ -236,6 +247,7 @@ register_llm_model_chat_adapter(GorillaChatAdapter)
register_llm_model_chat_adapter(GPT4AllChatAdapter)
register_llm_model_chat_adapter(Llama2ChatAdapter)
register_llm_model_chat_adapter(BaichuanChatAdapter)
register_llm_model_chat_adapter(WizardLMChatAdapter)

# Proxy model for test and develop, it's cheap for us now.
register_llm_model_chat_adapter(ProxyllmChatAdapter)
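Putting the new adapter API together: the server resolves an adapter from the model path, then asks it for the conversation template and the stream handler. This mirrors the `ModelWorker` change in `pilot/server/llmserver.py` below; the path used here is an assumed example:

```python
model_path = "./models/vicuna-13b-v1.5"                 # assumed local path

llm_chat_adapter = get_llm_chat_adapter(model_path)     # -> VicunaChatAdapter
conv_template = llm_chat_adapter.get_conv_template(model_path)                # vicuna_v1.1 for v1.5 paths
generate_stream_func = llm_chat_adapter.get_generate_stream_func(model_path)  # default generate_stream
```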
@@ -31,15 +31,16 @@ CFG = Config()


class ModelWorker:
    def __init__(self, model_path, model_name, device, num_gpus=1):
    def __init__(self, model_path, model_name, device):
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device
        print(f"Loading {model_name} LLM ModelServer in {device}! Please Wait......")
        self.ml = ModelLoader(model_path=model_path, model_name=self.model_name)
        self.ml: ModelLoader = ModelLoader(
            model_path=model_path, model_name=self.model_name
        )
        self.model, self.tokenizer = self.ml.loader(
            num_gpus,
            load_8bit=CFG.IS_LOAD_8BIT,
            load_4bit=CFG.IS_LOAD_4BIT,
            debug=ISDEBUG,
@@ -60,7 +61,9 @@ class ModelWorker:
            self.context_len = 2048

        self.llm_chat_adapter = get_llm_chat_adapter(model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func()
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
            model_path
        )

    def start_check(self):
        print("LLM Model Loading Success!")
@@ -111,9 +114,7 @@ class ModelWorker:


model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
worker = ModelWorker(
    model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE, num_gpus=1
)
worker = ModelWorker(model_path=model_path, model_name=CFG.LLM_MODEL, device=DEVICE)

app = FastAPI()
# from pilot.openapi.knowledge.knowledge_controller import router