Merge branch 'main' into tt_dev

csunny 2023-10-12 09:23:06 +08:00
commit eef2b381aa
39 changed files with 1001 additions and 197 deletions

.gitignore vendored
View File

@@ -151,4 +151,6 @@ pilot/mock_datas/db-gpt-test.db.wal
 logswebserver.log.*
 .history/*
 .plugin_env
+# Ignore for now
+thirdparty

View File

@@ -10,11 +10,11 @@ git clone https://github.com/<YOUR-GITHUB-USERNAME>/DB-GPT
 ```
 3. Install the project requirements
 ```
-pip install -r requirements.txt
+pip install -r requirements/dev-requirements.txt
 ```
 4. Install pre-commit hooks
 ```
-pre-commit install
+pre-commit install --allow-missing-config
 ```
 5. Create a new branch for your changes using the following command:

View File

@@ -41,6 +41,7 @@ services:
 restart: unless-stopped
 networks:
 - dbgptnet
+ipc: host
 deploy:
 resources:
 reservations:

View File

@@ -1,6 +1,6 @@
 #!/bin/bash
-docker run --gpus all -d \
+docker run --ipc host --gpus all -d \
 -p 5000:5000 \
 -e LOCAL_DB_TYPE=sqlite \
 -e LOCAL_DB_PATH=data/default_sqlite.db \

View File

@@ -4,7 +4,7 @@
 PROXY_API_KEY="$PROXY_API_KEY"
 PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}"
-docker run --gpus all -d \
+docker run --ipc host --gpus all -d \
 -p 5000:5000 \
 -e LOCAL_DB_TYPE=sqlite \
 -e LOCAL_DB_PATH=data/default_sqlite.db \

View File

@@ -21,6 +21,7 @@ services:
 restart: unless-stopped
 networks:
 - dbgptnet
+ipc: host
 deploy:
 resources:
 reservations:

View File

@@ -47,7 +47,7 @@ You can execute the command `bash docker/build_all_images.sh --help` to see more
 **Run with local model and SQLite database**
 ```bash
-docker run --gpus all -d \
+docker run --ipc host --gpus all -d \
 -p 5000:5000 \
 -e LOCAL_DB_TYPE=sqlite \
 -e LOCAL_DB_PATH=data/default_sqlite.db \
@@ -73,7 +73,7 @@ docker logs dbgpt -f
 **Run with local model and MySQL database**
 ```bash
-docker run --gpus all -d -p 3306:3306 \
+docker run --ipc host --gpus all -d -p 3306:3306 \
 -p 5000:5000 \
 -e LOCAL_DB_HOST=127.0.0.1 \
 -e LOCAL_DB_PASSWORD=aa123456 \

View File

@@ -30,3 +30,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
 ./llama/llama_cpp.md
 ./quantization/quantization.md
+./vllm/vllm.md

View File

@@ -0,0 +1,26 @@
vLLM
==================================
[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving.
## Running vLLM
### Installing Dependencies
vLLM is an optional dependency in DB-GPT, and you can manually install it using the following command:
```bash
pip install -e ".[vllm]"
```
### Modifying the Configuration File
Next, you can directly modify your `.env` file to enable vllm.
```env
LLM_MODEL=vicuna-13b-v1.5
MODEL_TYPE=vllm
```
You can view the models supported by vLLM [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html#supported-models)
Then you can run it according to [Run](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run).
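If the install step succeeds, the `vllm` package should import cleanly before the server is started. A minimal, optional check (a sketch, not part of the committed docs; the version attribute is read defensively):

```python
# Sketch only: verify that the optional vLLM extra resolved after
# `pip install -e ".[vllm]"`. Not part of the committed documentation.
import importlib.util

if importlib.util.find_spec("vllm") is None:
    raise SystemExit('vLLM is not installed; run: pip install -e ".[vllm]"')

import vllm

# __version__ is assumed to exist; fall back gracefully if it does not.
print(f"vLLM {getattr(vllm, '__version__', 'unknown')} is available")
```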

View File

@@ -0,0 +1,79 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.9\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-10-09 19:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/llm/vllm/vllm.md:1
#: 9193438ba52148a3b71f190beeb4ef42
msgid "vLLM"
msgstr ""
#: ../../getting_started/install/llm/vllm/vllm.md:4
#: c30d032965794d7e81636581324be45d
msgid ""
"[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use "
"library for LLM inference and serving."
msgstr "[vLLM](https://github.com/vllm-project/vllm) 是一个快速且易于使用的 LLM 推理和服务的库。"
#: ../../getting_started/install/llm/vllm/vllm.md:6
#: b399c7268e0448cb893fbcb11a480849
msgid "Running vLLM"
msgstr "运行 vLLM"
#: ../../getting_started/install/llm/vllm/vllm.md:8
#: 7bed52b8bac946069a24df9e94098df5
msgid "Installing Dependencies"
msgstr "安装依赖"
#: ../../getting_started/install/llm/vllm/vllm.md:10
#: fd50a9f3e1b1459daa3b1a0cd610d1a3
msgid ""
"vLLM is an optional dependency in DB-GPT, and you can manually install it"
" using the following command:"
msgstr "vLLM 在 DB-GPT 是一个可选依赖, 你可以使用下面的命令手动安装它:"
#: ../../getting_started/install/llm/vllm/vllm.md:16
#: 44b251bc6b2c41ebaad9fd5a6a204c7c
msgid "Modifying the Configuration File"
msgstr "修改配置文件"
#: ../../getting_started/install/llm/vllm/vllm.md:18
#: 37f4e65148fa4339969265107b70b8fe
msgid "Next, you can directly modify your `.env` file to enable vllm."
msgstr "你可以直接修改你的 `.env` 文件。"
#: ../../getting_started/install/llm/vllm/vllm.md:24
#: 15d79c9417d04e779fa00a08a05e30d7
msgid ""
"You can view the models supported by vLLM "
"[here](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models)"
msgstr ""
"你可以在 "
"[这里](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models) 查看 vLLM 支持的模型。"
#: ../../getting_started/install/llm/vllm/vllm.md:26
#: 28d90b1fdf6943d9969c8668a7c1094b
msgid ""
"Then you can run it according to [Run](https://db-"
"gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run)."
msgstr ""
"然后你可以根据[运行]"
"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/getting_started/install/deploy/deploy.html#run)来启动项目。"

View File

@@ -204,6 +204,9 @@ class Config(metaclass=Singleton):
 self.SYSTEM_APP: Optional["SystemApp"] = None
+### Temporary configuration
+self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
 def set_debug_mode(self, value: bool) -> None:
 """Set the debug mode value"""
 self.debug_mode = value
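The new `USE_FASTCHAT` flag is read straight from the environment, and `DefaultModelWorker.load_worker` (changed later in this commit) forwards it to `get_llm_model_adapter` to choose between the fastchat-backed adapter and the legacy one. A minimal sketch of that toggle, with a hypothetical model name and path:

```python
# Sketch of the USE_FASTCHAT toggle as used by DefaultModelWorker.load_worker
# elsewhere in this commit; the model name and path are placeholders.
import os

from pilot.model.model_adapter import get_llm_model_adapter

use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"
adapter = get_llm_model_adapter(
    "vicuna-13b-v1.5",                # hypothetical model name
    "/data/models/vicuna-13b-v1.5",   # hypothetical local path
    use_fastchat=use_fastchat,
    model_type=None,                  # or ModelType.VLLM to force the vLLM adapter
)
print(adapter.model_type())
```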

View File

@@ -78,6 +78,8 @@ LLM_MODEL_CONFIG = {
 "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
 "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
 "internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
+# For test now
+"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
 }
 EMBEDDING_MODEL_CONFIG = {

View File

@@ -29,7 +29,7 @@ class SeparatorStyle(Enum):
 @dataclasses.dataclass
-class Conversation:
+class OldConversation:
 """This class keeps all conversation history."""
 system: str
@@ -81,7 +81,7 @@ class Conversation:
 return ret
 def copy(self):
-return Conversation(
+return OldConversation(
 system=self.system,
 roles=self.roles,
 messages=[[x, y] for x, y in self.messages],
@@ -104,7 +104,7 @@ class Conversation:
 }
-conv_default = Conversation(
+conv_default = OldConversation(
 system=None,
 roles=("human", "ai"),
 messages=[],
@@ -148,7 +148,7 @@ conv_default = Conversation(
 # )
-conv_one_shot = Conversation(
+conv_one_shot = OldConversation(
 system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. Here is an example:",
 roles=("USER", "ASSISTANT"),
 messages=(
@@ -179,7 +179,7 @@ conv_one_shot = Conversation(
 sep="###",
 )
-conv_vicuna_v1 = Conversation(
+conv_vicuna_v1 = OldConversation(
 system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
 "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
 roles=("USER", "ASSISTANT"),
@@ -190,7 +190,7 @@ conv_vicuna_v1 = Conversation(
 sep2="</s>",
 )
-auto_dbgpt_one_shot = Conversation(
+auto_dbgpt_one_shot = OldConversation(
 system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
 "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
 roles=("USER", "ASSISTANT"),
@@ -263,7 +263,7 @@ auto_dbgpt_one_shot = Conversation(
 sep="###",
 )
-auto_dbgpt_without_shot = Conversation(
+auto_dbgpt_without_shot = OldConversation(
 system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
 "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
 roles=("USER", "ASSISTANT"),

View File

@@ -1,3 +1,7 @@
+"""
+This code file will be deprecated in the future.
+We have integrated fastchat. For details, see: pilot/model/model_adapter.py
+"""
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
@@ -13,6 +17,8 @@ from transformers import (
 AutoTokenizer,
 LlamaTokenizer,
 )
+from pilot.model.base import ModelType
 from pilot.model.parameter import (
 ModelParameters,
 LlamaCppModelParameters,
@@ -26,15 +32,6 @@ logger = logging.getLogger(__name__)
 CFG = Config()
-class ModelType:
-""" "Type of model"""
-HF = "huggingface"
-LLAMA_CPP = "llama.cpp"
-PROXY = "proxy"
-# TODO, support more model type
 class BaseLLMAdaper:
 """The Base class for multi model, in our project.
 We will support those model, which performance resemble ChatGPT"""
@@ -95,33 +92,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
 )
-def _dynamic_model_parser() -> Callable[[None], List[Type]]:
-from pilot.utils.parameter_utils import _SimpleArgParser
-from pilot.model.parameter import (
-EmbeddingModelParameters,
-WorkerType,
-EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
-)
-pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
-pre_args.parse()
-model_name = pre_args.get("model_name")
-model_path = pre_args.get("model_path")
-worker_type = pre_args.get("worker_type")
-if model_name is None:
-return None
-if worker_type == WorkerType.TEXT2VEC:
-return [
-EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
-model_name, EmbeddingModelParameters
-)
-]
-llm_adapter = get_llm_model_adapter(model_name, model_path)
-param_class = llm_adapter.model_param_class()
-return [param_class]
 def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
 try:
 llm_adapter = get_llm_model_adapter(model_name, model_path)

View File

@@ -15,6 +15,16 @@ class Message(TypedDict):
 content: str
+class ModelType:
+""" "Type of model"""
+HF = "huggingface"
+LLAMA_CPP = "llama.cpp"
+PROXY = "proxy"
+VLLM = "vllm"
+# TODO, support more model type
 @dataclass
 class ModelInstance:
 """Model instance info"""

View File

@@ -404,7 +404,7 @@ def stop_model_controller(port: int):
 def _model_dynamic_factory() -> Callable[[None], List[Type]]:
-from pilot.model.adapter import _dynamic_model_parser
+from pilot.model.model_adapter import _dynamic_model_parser
 param_class = _dynamic_model_parser()
 fix_class = [ModelWorkerParameters]

View File

@@ -1,26 +1,34 @@
+import os
 import logging
 from typing import Dict, Iterator, List
 from pilot.configs.model_config import get_device
-from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
+from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
 from pilot.model.base import ModelOutput
 from pilot.model.loader import ModelLoader, _get_model_real_path
 from pilot.model.parameter import ModelParameters
 from pilot.model.cluster.worker_base import ModelWorker
-from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
 from pilot.utils.model_utils import _clear_model_cache
 from pilot.utils.parameter_utils import EnvArgumentParser
 logger = logging.getLogger(__name__)
+_torch_imported = False
+try:
+import torch
+_torch_imported = True
+except ImportError:
+pass
 class DefaultModelWorker(ModelWorker):
 def __init__(self) -> None:
 self.model = None
 self.tokenizer = None
 self._model_params = None
-self.llm_adapter: BaseLLMAdaper = None
-self.llm_chat_adapter: BaseChatAdpter = None
+self.llm_adapter: LLMModelAdaper = None
+self._support_async = False
 def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
 if model_path.endswith("/"):
@@ -29,18 +37,24 @@ class DefaultModelWorker(ModelWorker):
 self.model_name = model_name
 self.model_path = model_path
-self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+model_type = kwargs.get("model_type")
+### Temporary configuration, fastchat will be used by default in the future.
+use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"
+self.llm_adapter = get_llm_model_adapter(
+self.model_name,
+self.model_path,
+use_fastchat=use_fastchat,
+model_type=model_type,
+)
 model_type = self.llm_adapter.model_type()
 self.param_cls = self.llm_adapter.model_param_class(model_type)
+self._support_async = self.llm_adapter.support_async()
 logger.info(
 f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
 )
-self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
-self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
-self.model_path
-)
 self.ml: ModelLoader = ModelLoader(
 model_path=self.model_path, model_name=self.model_name
 )
@@ -50,6 +64,9 @@ class DefaultModelWorker(ModelWorker):
 def model_param_class(self) -> ModelParameters:
 return self.param_cls
+def support_async(self) -> bool:
+return self._support_async
 def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
 param_cls = self.model_param_class()
 model_args = EnvArgumentParser()
@@ -77,7 +94,9 @@ class DefaultModelWorker(ModelWorker):
 model_params = self.parse_parameters(command_args)
 self._model_params = model_params
 logger.info(f"Begin load model, model params: {model_params}")
-self.model, self.tokenizer = self.ml.loader_with_params(model_params)
+self.model, self.tokenizer = self.ml.loader_with_params(
+model_params, self.llm_adapter
+)
 def stop(self) -> None:
 if not self.model:
@@ -90,51 +109,26 @@ class DefaultModelWorker(ModelWorker):
 _clear_model_cache(self._model_params.device)
 def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
-torch_imported = False
 try:
-import torch
-torch_imported = True
-except ImportError:
-pass
-try:
-# params adaptation
-params, model_context = self.llm_chat_adapter.model_adaptation(
-params, self.ml.model_path, prompt_template=self.ml.prompt_template
+params, model_context, generate_stream_func = self._prepare_generate_stream(
+params
 )
 previous_response = ""
-print("stream output:\n")
-for output in self.generate_stream_func(
+for output in generate_stream_func(
 self.model, self.tokenizer, params, get_device(), self.context_len
 ):
-# Please do not open the output in production!
-# The gpt4all thread shares stdout with the parent process,
-# and opening it may affect the frontend output.
-incremental_output = output[len(previous_response) :]
-# print("output: ", output)
-print(incremental_output, end="", flush=True)
-previous_response = output
-# return some model context to dgt-server
-model_output = ModelOutput(
-text=output, error_code=0, model_context=model_context
+model_output, incremental_output, output_str = self._handle_output(
+output, previous_response, model_context
 )
+previous_response = output_str
 yield model_output
 print(
 f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
 )
 except Exception as e:
-# Check if the exception is a torch.cuda.CudaError and if torch was imported.
-if torch_imported and isinstance(e, torch.cuda.CudaError):
-model_output = ModelOutput(
-text="**GPU OutOfMemory, Please Refresh.**", error_code=0
-)
-else:
-model_output = ModelOutput(
-text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
-error_code=0,
-)
-yield model_output
+yield self._handle_exception(e)
 def generate(self, params: Dict) -> ModelOutput:
 """Generate non stream result"""
@@ -145,3 +139,81 @@ class DefaultModelWorker(ModelWorker):
 def embeddings(self, params: Dict) -> List[List[float]]:
 raise NotImplementedError
+async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
+try:
+params, model_context, generate_stream_func = self._prepare_generate_stream(
+params
+)
+previous_response = ""
+async for output in generate_stream_func(
+self.model, self.tokenizer, params, get_device(), self.context_len
+):
+model_output, incremental_output, output_str = self._handle_output(
+output, previous_response, model_context
+)
+previous_response = output_str
+yield model_output
+print(
+f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
+)
+except Exception as e:
+yield self._handle_exception(e)
+async def async_generate(self, params: Dict) -> ModelOutput:
+output = None
+async for out in self.async_generate_stream(params):
+output = out
+return output
+def _prepare_generate_stream(self, params: Dict):
+params, model_context = self.llm_adapter.model_adaptation(
+params,
+self.model_name,
+self.model_path,
+prompt_template=self.ml.prompt_template,
+)
+stream_type = ""
+if self.support_async():
+generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
+self.model, self.model_path
+)
+stream_type = "async "
+logger.info(
+"current generate stream function is asynchronous stream function"
+)
+else:
+generate_stream_func = self.llm_adapter.get_generate_stream_function(
+self.model, self.model_path
+)
+str_prompt = params.get("prompt")
+print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
+return params, model_context, generate_stream_func
+def _handle_output(self, output, previous_response, model_context):
+if isinstance(output, dict):
+finish_reason = output.get("finish_reason")
+output = output["text"]
+if finish_reason is not None:
+logger.info(f"finish_reason: {finish_reason}")
+incremental_output = output[len(previous_response) :]
+print(incremental_output, end="", flush=True)
+model_output = ModelOutput(
+text=output, error_code=0, model_context=model_context
+)
+return model_output, incremental_output, output
+def _handle_exception(self, e):
+# Check if the exception is a torch.cuda.CudaError and if torch was imported.
+if _torch_imported and isinstance(e, torch.cuda.CudaError):
+model_output = ModelOutput(
+text="**GPU OutOfMemory, Please Refresh.**", error_code=0
+)
+else:
+model_output = ModelOutput(
+text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
+error_code=0,
+)
+return model_output
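The worker now exposes a synchronous and an asynchronous streaming path, selected via `support_async()`; `async_generate` simply keeps the last chunk of `async_generate_stream`. A minimal sketch of a caller that works with either path, assuming the worker's model is already loaded (the import path follows the default noted in `ModelWorkerParameters`):

```python
# Sketch only: assumes `worker` has already gone through load_worker and
# parameter parsing/model loading. The import path mirrors the default
# "pilot.model.cluster.DefaultModelWorker" mentioned in ModelWorkerParameters.
import asyncio
from typing import Dict

from pilot.model.base import ModelOutput
from pilot.model.cluster import DefaultModelWorker


def last_output(worker: DefaultModelWorker, params: Dict) -> ModelOutput:
    """Collect the final chunk from whichever streaming path the adapter supports."""
    if worker.support_async():
        # vLLM-backed adapters report support_async() == True
        async def _run() -> ModelOutput:
            last = None
            async for out in worker.async_generate_stream(params):
                last = out  # each chunk is a cumulative ModelOutput
            return last

        return asyncio.run(_run())

    last = None
    for out in worker.generate_stream(params):
        last = out
    return last
```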

View File

@@ -3,7 +3,9 @@ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conver
 Conversation prompt templates.
 TODO Using fastchat core package
+This code file will be deprecated in the future.
+We have integrated fastchat. For details, see: pilot/model/model_adapter.py
 """
import dataclasses import dataclasses

View File

@@ -1,6 +1,8 @@
 """
 Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
+This code file will be deprecated in the future.
+We have integrated fastchat. For details, see: pilot/model/model_adapter.py
 """
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

View File

@@ -7,7 +7,6 @@ import time
 from typing import Optional
 from pilot.configs.config import Config
-from pilot.conversation import Conversation
 # TODO Rewrite this
@@ -44,36 +43,6 @@ def retry_stream_api(
 return _wrapper
-# Overly simple abstraction util we create something better
-# simple retry mechanism when getting a rate error or a bad gateway
-def create_chat_competion(
-conv: Conversation,
-model: Optional[str] = None,
-temperature: float = None,
-max_new_tokens: Optional[int] = None,
-) -> str:
-"""Create a chat completion using the Vicuna-13b
-Args:
-messages(List[Message]): The messages to send to the chat completion
-model (str, optional): The model to use. Default to None.
-temperature (float, optional): The temperature to use. Defaults to 0.7.
-max_tokens (int, optional): The max tokens to use. Defaults to None.
-Returns:
-str: The response from the chat completion
-"""
-cfg = Config()
-if temperature is None:
-temperature = cfg.temperature
-# TODO request vicuna model get response
-# convert vicuna message to chat completion.
-for plugin in cfg.plugins:
-if plugin.can_handle_chat_completion():
-pass
 class ChatIO(abc.ABC):
 @abc.abstractmethod
 def prompt_for_input(self, role: str) -> str:

View File

@@ -1,7 +1,6 @@
 import torch
 from threading import Thread
 from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
-from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
 def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):

View File

@@ -0,0 +1,56 @@
from typing import Dict
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
async def generate_stream(
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
"""
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
"""
prompt = params["prompt"]
request_id = params.pop("request_id") if "request_id" in params else random_uuid()
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if tokenizer.eos_token_id is not None:
stop_token_ids.append(tokenizer.eos_token_id)
# Handle stop_str
stop = set()
if isinstance(stop_str, str) and stop_str != "":
stop.add(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.update(stop_str)
for tid in stop_token_ids:
if tid is not None:
stop.add(tokenizer.decode(tid))
# make sampling params in vllm
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
stop=list(stop),
max_tokens=max_new_tokens,
)
results_generator = model.generate(prompt, sampling_params, request_id)
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [prompt + output.text for output in request_output.outputs]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
yield {"text": text_outputs, "error_code": 0, "usage": {}}

View File

@@ -146,7 +146,7 @@ def list_supported_models():
 def _list_supported_models(
 worker_type: str, model_config: Dict[str, str]
 ) -> List[SupportedModel]:
-from pilot.model.adapter import get_llm_model_adapter
+from pilot.model.model_adapter import get_llm_model_adapter
 from pilot.model.parameter import ModelParameters
 from pilot.model.loader import _get_model_real_path

View File

@@ -3,9 +3,11 @@
 from typing import Optional, Dict
+from dataclasses import asdict
 import logging
 from pilot.configs.model_config import get_device
-from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
+from pilot.model.base import ModelType
+from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
 from pilot.model.parameter import (
 ModelParameters,
 LlamaCppModelParameters,
@@ -114,8 +116,9 @@ class ModelLoader:
 else:
 raise Exception(f"Unkown model type {model_type}")
-def loader_with_params(self, model_params: ModelParameters):
-llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
+def loader_with_params(
+self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
+):
 model_type = llm_adapter.model_type()
 self.prompt_template = model_params.prompt_template
 if model_type == ModelType.HF:
@@ -124,11 +127,13 @@ class ModelLoader:
 return llamacpp_loader(llm_adapter, model_params)
 elif model_type == ModelType.PROXY:
 return proxyllm_loader(llm_adapter, model_params)
+elif model_type == ModelType.VLLM:
+return llm_adapter.load_from_params(model_params)
 else:
 raise Exception(f"Unkown model type {model_type}")
-def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
+def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
 import torch
 from pilot.model.compression import compress_module
@@ -181,7 +186,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
 f"Current model {model_params.model_name} not supported quantization"
 )
 # default loader
-model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)
+model, tokenizer = llm_adapter.load(model_params.model_path, kwargs)
 if model_params.load_8bit and num_gpus == 1 and tokenizer:
 # TODO merge current code into `load_huggingface_quantization_model`
@@ -204,7 +209,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
 def load_huggingface_quantization_model(
-llm_adapter: BaseLLMAdaper,
+llm_adapter: LLMModelAdaper,
 model_params: ModelParameters,
 kwargs: Dict,
 max_memory: Dict[int, str],
@@ -339,7 +344,7 @@ def load_huggingface_quantization_model(
 return model, tokenizer
-def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters):
+def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
 try:
 from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel
 except ImportError as exc:
@@ -353,7 +358,7 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
 return model, tokenizer
-def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
+def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
 from pilot.model.proxy.llms.proxy_model import ProxyModel
 logger.info("Load proxyllm")

View File

@@ -0,0 +1,431 @@
from __future__ import annotations
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
import dataclasses
import logging
import threading
import os
from functools import cache
from pilot.model.base import ModelType
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.utils.parameter_utils import (
_extract_parameter_details,
_build_parameter_class,
_get_dataclass_print_str,
)
try:
from fastchat.conversation import (
Conversation,
register_conv_template,
SeparatorStyle,
)
except ImportError as exc:
raise ValueError(
"Could not import python package: fschat "
"Please install fastchat by command `pip install fschat` "
) from exc
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
from pilot.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
from torch.nn import Module as TorchNNModule
logger = logging.getLogger(__name__)
thread_local = threading.local()
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
]
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
def model_type(self) -> str:
return ModelType.HF
def model_param_class(self, model_type: str = None) -> ModelParameters:
"""Get the startup parameters instance of the model"""
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def load(self, model_path: str, from_pretrained_kwargs: dict):
"""Load model and tokenizer"""
raise NotImplementedError
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
raise NotImplementedError
def get_async_generate_stream_function(self, model, model_path: str):
"""Get the asynchronous generate stream function of the model"""
raise NotImplementedError
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
"""Get the default conv template"""
raise NotImplementedError
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_default_conv_template(model_name, model_path)
messages = params.get("messages")
# Some model scontext to dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
conv = conv.copy()
system_messages = []
for message in messages:
role, content = None, None
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
conv.append_message(conv.roles[0], content)
elif role == ModelMessageRoleType.AI:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if system_messages:
conv.set_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
# Overwrite the original prompt
# TODO remote bos token and eos token from tokenizer_config.json of model
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
# Overwrite model params:
params["stop"] = conv.stop_str
return params, model_context
class OldLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping old adapter, which may be removed later"""
def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
self._adapter = adapter
self._chat_adapter = chat_adapter
def model_type(self) -> str:
return self._adapter.model_type()
def model_param_class(self, model_type: str = None) -> ModelParameters:
return self._adapter.model_param_class(model_type)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._chat_adapter.get_conv_template(model_path)
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.loader(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model, model_path: str):
return self._chat_adapter.get_generate_stream_func(model_path)
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping fastchat adapter"""
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
from fastchat.model.model_adapter import get_generate_stream_function
return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._adapter.get_default_conv_template(model_path)
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
return get_conv_template(name)
@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
try:
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
return adapter.get_default_conv_template(model_name, model_path)
except Exception:
return None
@cache
def get_llm_model_adapter(
model_name: str,
model_path: str,
use_fastchat: bool = True,
use_fastchat_monkey_patch: bool = False,
model_type: str = None,
) -> LLMModelAdaper:
if model_type == ModelType.VLLM:
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()
must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
adapter = _get_fastchat_model_adapter(
model_name,
model_path,
_fastchat_get_adapter_monkey_patch,
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
)
return FastChatLLMModelAdaperWrapper(adapter)
else:
from pilot.model.adapter import (
get_llm_model_adapter as _old_get_llm_model_adapter,
)
from pilot.server.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
return OldLLMModelAdaperWrapper(
_old_get_llm_model_adapter(model_name, model_path),
get_llm_chat_adapter(model_name, model_path),
)
def _get_fastchat_model_adapter(
model_name: str,
model_path: str,
caller: Callable[[str], None] = None,
use_fastchat_monkey_patch: bool = False,
):
from fastchat.model import model_adapter
_bak_get_model_adapter = model_adapter.get_model_adapter
try:
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
if caller:
return caller(model_path)
finally:
del thread_local.model_name
model_adapter.get_model_adapter = _bak_get_model_adapter
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
if not model_name:
if not hasattr(thread_local, "model_name"):
raise RuntimeError("fastchat get adapter monkey path need model_name")
model_name = thread_local.model_name
from fastchat.model.model_adapter import model_adapters
for adapter in model_adapters:
if adapter.match(model_name):
logger.info(
f"Found llm model adapter with model name: {model_name}, {adapter}"
)
return adapter
model_path_basename = (
None if not model_path else os.path.basename(os.path.normpath(model_path))
)
for adapter in model_adapters:
if model_path_basename and adapter.match(model_path_basename):
logger.info(
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
)
return adapter
for adapter in model_adapters:
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
return adapter
raise ValueError(
f"Invalid model adapter for model name {model_name} and model path {model_path}"
)
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
class VLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping vllm engine"""
def model_type(self) -> str:
return ModelType.VLLM
def model_param_class(self, model_type: str = None) -> ModelParameters:
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument(
"--model_path",
type=str,
help="local model path of the huggingface model to use",
)
parser.add_argument("--model_type", type=str, help="model type")
parser.add_argument("--device", type=str, default=None, help="device")
# TODO parse prompt templete from `model_name` and `model_path`
parser.add_argument(
"--prompt_template",
type=str,
default=None,
help="Prompt template. If None, the prompt template is automatically determined from model path",
)
descs = _extract_parameter_details(
parser,
"pilot.model.parameter.VLLMModelParameters",
skip_names=["model"],
overwrite_default_values={"trust_remote_code": True},
)
return _build_parameter_class(descs)
def load_from_params(self, params):
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
import torch
num_gpus = torch.cuda.device_count()
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
setattr(params, "tensor_parallel_size", num_gpus)
logger.info(
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
)
params = dataclasses.asdict(params)
params["model"] = params["model_path"]
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
# Set the attributes from the parsed arguments.
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, engine.engine.tokenizer
def support_async(self) -> bool:
return True
def get_async_generate_stream_function(self, model, model_path: str):
from pilot.model.llm_out.vllm_llm import generate_stream
return generate_stream
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return _auto_get_conv_template(model_name, model_path)
# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
register_conv_template(
Conversation(
name="internlm-chat",
system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
# TODO feedback stop_str to fastchat
stop_str="<eoa>",
),
override=True,
)
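A minimal sketch of the new entry points, with a hypothetical model name and path: resolve an adapter, then let `model_adaptation` turn chat messages into the final prompt and stop settings:

```python
# Sketch only: the model name and path are placeholders. model_adaptation
# fills in params["prompt"] and params["stop"] from the resolved template.
from pilot.model.model_adapter import get_llm_model_adapter
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

model_name = "vicuna-13b-v1.5"                 # hypothetical
model_path = "/data/models/vicuna-13b-v1.5"    # hypothetical

adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)

params = {
    "messages": [
        ModelMessage(role=ModelMessageRoleType.SYSTEM, content="You are DB-GPT."),
        ModelMessage(role=ModelMessageRoleType.HUMAN, content="List all tables."),
    ],
    "echo": False,
}
params, model_context = adapter.model_adaptation(params, model_name, model_path)

print(params["prompt"])
print(params.get("stop"), model_context["prompt_echo_len_char"])
```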

View File

@@ -64,6 +64,13 @@ class ModelWorkerParameters(BaseModelParameters):
 default=None,
 metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
 )
+model_type: Optional[str] = field(
+default="huggingface",
+metadata={
+"help": "Model type: huggingface, llama.cpp, proxy and vllm",
+"tags": "fixed",
+},
+)
 host: Optional[str] = field(
 default="0.0.0.0", metadata={"help": "Model worker deploy host"}
 )
@@ -163,7 +170,7 @@ class ModelParameters(BaseModelParameters):
 model_type: Optional[str] = field(
 default="huggingface",
 metadata={
-"help": "Model type, huggingface, llama.cpp and proxy",
+"help": "Model type: huggingface, llama.cpp, proxy and vllm",
 "tags": "fixed",
 },
 )
@@ -292,7 +299,7 @@ class ProxyModelParameters(BaseModelParameters):
 model_type: Optional[str] = field(
 default="proxy",
 metadata={
-"help": "Model type, huggingface, llama.cpp and proxy",
+"help": "Model type: huggingface, llama.cpp, proxy and vllm",
 "tags": "fixed",
 },
 )

View File

@@ -140,7 +140,7 @@ class BaseChat(ABC):
 payload = self.__call_base()
 self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
-logger.info(f"Requert: \n{payload}")
+logger.info(f"Request: \n{payload}")
 ai_response_text = ""
 try:
 from pilot.model.cluster import WorkerManagerFactory

View File

@@ -19,7 +19,7 @@ def signal_handler(sig, frame):
 os._exit(0)
-def async_db_summery(system_app: SystemApp):
+def async_db_summary(system_app: SystemApp):
 from pilot.summary.db_summary_client import DBSummaryClient
 client = DBSummaryClient(system_app=system_app)
@@ -79,7 +79,7 @@ def _create_model_start_listener(system_app: SystemApp):
 print("begin run _add_app_startup_event")
 conn_manage = ConnectManager(system_app)
 cfg.LOCAL_DB_MANAGE = conn_manage
-async_db_summery(system_app)
+async_db_summary(system_app)
 return startup_event

View File

@@ -1,3 +1,7 @@
+"""
+This code file will be deprecated in the future.
+We have integrated fastchat. For details, see: pilot/model/model_adapter.py
+"""
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-

View File

View File

@@ -1,18 +1,18 @@
 class DBSummary:
 def __init__(self, name):
 self.name = name
-self.summery = None
+self.summary = None
 self.tables = []
 self.metadata = str
-def get_summery(self):
-return self.summery
+def get_summary(self):
+return self.summary
 class TableSummary:
 def __init__(self, name):
 self.name = name
-self.summery = None
+self.summary = None
 self.fields = []
 self.indexes = []
@@ -20,12 +20,12 @@ class TableSummary:
 class FieldSummary:
 def __init__(self, name):
 self.name = name
-self.summery = None
+self.summary = None
 self.data_type = None
 class IndexSummary:
 def __init__(self, name):
 self.name = name
-self.summery = None
+self.summary = None
 self.bind_fields = []

View File

@@ -47,13 +47,13 @@ class DBSummaryClient:
 "embeddings": embeddings,
 }
 embedding = StringEmbedding(
-file_path=db_summary_client.get_summery(),
+file_path=db_summary_client.get_summary(),
 vector_store_config=vector_store_config,
 )
 self.init_db_profile(db_summary_client, dbname, embeddings)
 if not embedding.vector_name_exist():
 if CFG.SUMMARY_CONFIG == "FAST":
-for vector_table_info in db_summary_client.get_summery():
+for vector_table_info in db_summary_client.get_summary():
 embedding = StringEmbedding(
 vector_table_info,
 vector_store_config,
@@ -61,7 +61,7 @@ class DBSummaryClient:
 embedding.source_embedding()
 else:
 embedding = StringEmbedding(
-file_path=db_summary_client.get_summery(),
+file_path=db_summary_client.get_summary(),
 vector_store_config=vector_store_config,
 )
 embedding.source_embedding()
@@ -144,8 +144,8 @@ class DBSummaryClient:
 vector_store_config=vector_store_config,
 embedding_factory=embedding_factory,
 )
-table_summery = knowledge_embedding_client.similar_search(query, 1)
-related_table_summaries.append(table_summery[0].page_content)
+table_summary = knowledge_embedding_client.similar_search(query, 1)
+related_table_summaries.append(table_summary[0].page_content)
 return related_table_summaries
 def init_db_summary(self):
@@ -169,7 +169,7 @@ class DBSummaryClient:
 "embeddings": embeddings,
 }
 embedding = StringEmbedding(
-file_path=db_summary_client.get_db_summery(),
+file_path=db_summary_client.get_db_summary(),
 vector_store_config=profile_store_config,
 )
 if not embedding.vector_name_exist():

View File

@ -12,7 +12,7 @@ class RdbmsSummary(DBSummary):
def __init__(self, name, type): def __init__(self, name, type):
self.name = name self.name = name
self.type = type self.type = type
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.tables = {} self.tables = {}
self.tables_info = [] self.tables_info = []
self.vector_tables_info = [] self.vector_tables_info = []
@ -48,7 +48,7 @@ class RdbmsSummary(DBSummary):
for table_name in tables: for table_name in tables:
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map) table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summery() # self.tables[table_name] = table_summary.get_summary()
self.tables[table_name] = table_summary.get_columns() self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns()) self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json()) # self.table_columns_json.append(table_summary.get_summary_json())
@ -59,18 +59,18 @@ class RdbmsSummary(DBSummary):
) )
) )
self.table_columns_json.append(table_profile) self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery()) # self.tables_info.append(table_summary.get_summary())
def get_summery(self): def get_summary(self):
if CFG.SUMMARY_CONFIG == "FAST": if CFG.SUMMARY_CONFIG == "FAST":
return self.vector_tables_info return self.vector_tables_info
else: else:
return self.summery.format( return self.summary.format(
            name=self.name, type=self.type, table_info=";".join(self.tables_info)
        )

    def get_db_summary(self):
        return self.summary.format(
            name=self.name,
            type=self.type,
            tables=";".join(self.vector_tables_info),
@@ -94,8 +94,8 @@ class RdbmsTableSummary(TableSummary):
    def __init__(self, instance, dbname, name, comment_map):
        self.name = name
        self.dbname = dbname
        self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
        self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
        self.fields = []
        self.fields_info = []
        self.indexes = []
@@ -107,19 +107,19 @@ class RdbmsTableSummary(TableSummary):
        for field in fields:
            field_summary = RdbmsFieldsSummary(field)
            self.fields.append(field_summary)
            self.fields_info.append(field_summary.get_summary())
            field_names.append(field[0])
        self.column_summary = """{name}({columns_info})""".format(
            name=name, columns_info=",".join(field_names)
        )
        for index in indexes:
            index_summary = RdbmsIndexSummary(index)
            self.indexes.append(index_summary)
            self.indexes_info.append(index_summary.get_summary())
        self.json_summary = self.json_summary_template.format(
            name=name,
            comment=comment_map[name],
            fields=self.fields_info,
@@ -128,8 +128,8 @@ class RdbmsTableSummary(TableSummary):
            rows=1000,
        )

    def get_summary(self):
        return self.summary.format(
            name=self.name,
            dbname=self.dbname,
            fields=";".join(self.fields_info),
@@ -137,10 +137,10 @@ class RdbmsTableSummary(TableSummary):
        )

    def get_columns(self):
        return self.column_summary

    def get_summary_json(self):
        return self.json_summary


class RdbmsFieldsSummary(FieldSummary):
@@ -148,14 +148,14 @@ class RdbmsFieldsSummary(FieldSummary):
    def __init__(self, field):
        self.name = field[0]
        # self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
        # self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
        self.data_type = field[1]
        self.default_value = field[2]
        self.is_nullable = field[3]
        self.comment = field[4]

    def get_summary(self):
        return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
            name=self.name,
            data_type=self.data_type,
@@ -170,11 +170,11 @@ class RdbmsIndexSummary(IndexSummary):
    def __init__(self, index):
        self.name = index[0]
        # self.summary = """index name:{name}, index bind columns:{bind_fields}"""
        self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
        self.bind_fields = index[1]

    def get_summary(self):
        return self.summary_template.format(
            name=self.name, bind_fields=self.bind_fields
        )
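For reference, a minimal usage sketch of the renamed `get_summary` helpers. The import path and the sample rows below are assumptions made for illustration only; adjust them to wherever these summary classes live in your checkout.

```python
# NOTE: the module path is a guess, not taken from this diff.
from pilot.summary.rdbms_db_summary import RdbmsFieldsSummary, RdbmsIndexSummary

# A column row in the shape the constructor above expects:
# (name, data_type, default_value, is_nullable, comment)
field = RdbmsFieldsSummary(("id", "BIGINT", None, "NO", "primary key"))
print(field.get_summary())
# -> {"name": "id", "type": "BIGINT", "is_primary_key": "NO", ...}

# An index row: (name, bound_columns)
index = RdbmsIndexSummary(("idx_user_id", "['user_id']"))
print(index.get_summary())
# -> {"name": "idx_user_id", "columns": ['user_id']}
```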


@@ -86,17 +86,7 @@ class BaseParameters:
        return updated

    def __str__(self) -> str:
        return _get_dataclass_print_str(self)

    def to_command_args(self, args_prefix: str = "--") -> List[str]:
        """Convert the fields of the dataclass to a list of command line arguments.

@@ -110,6 +100,20 @@ class BaseParameters:
        return _dict_to_command_args(asdict(self), args_prefix=args_prefix)


def _get_dataclass_print_str(obj):
    class_name = obj.__class__.__name__
    parameters = [
        f"\n\n=========================== {class_name} ===========================\n"
    ]
    for field_info in fields(obj):
        value = _get_simple_privacy_field_value(obj, field_info)
        parameters.append(f"{field_info.name}: {value}")
    parameters.append(
        "\n======================================================================\n\n"
    )
    return "\n".join(parameters)
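Since `BaseParameters.__str__` now delegates to `_get_dataclass_print_str`, every parameters dataclass prints the same banner-style block. A minimal sketch; the `DemoParameters` class and its fields are made up for illustration:

```python
from dataclasses import dataclass, field

from pilot.utils.parameter_utils import BaseParameters


@dataclass
class DemoParameters(BaseParameters):
    # Hypothetical fields, purely for illustration.
    model_name: str = field(default="vicuna-13b-v1.5", metadata={"help": "Model name"})
    port: int = field(default=5000, metadata={"help": "Service port"})


# __str__ delegates to _get_dataclass_print_str(), so this prints the
# "=== DemoParameters ===" header followed by one "name: value" line per field.
print(DemoParameters())
```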

def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
    """Convert dict to a list of command line arguments

@@ -493,9 +497,10 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    if not desc:
        raise ValueError("Parameter descriptions can't be empty")
    param_class_str = desc[0].param_class
    if param_class_str:
        param_class = import_from_string(param_class_str, ignore_import_error=True)
        if param_class:
            return param_class
    module_name, _, class_name = param_class_str.rpartition(".")
    fields_dict = {}  # This will store field names and their default values or field()
@@ -520,6 +525,71 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
    return result_class


def _extract_parameter_details(
    parser: argparse.ArgumentParser,
    param_class: str = None,
    skip_names: List[str] = None,
    overwrite_default_values: Dict = {},
) -> List[ParameterDescription]:
    descriptions = []
    for action in parser._actions:
        if (
            action.default == argparse.SUPPRESS
        ):  # typically this means the argument was not provided
            continue

        # determine parameter class (store_true/store_false are flags)
        flag_or_option = (
            "flag" if isinstance(action, argparse._StoreConstAction) else "option"
        )

        # extract parameter name (use the first option string, typically the long form)
        param_name = action.option_strings[0] if action.option_strings else action.dest
        if param_name.startswith("--"):
            param_name = param_name[2:]
        if param_name.startswith("-"):
            param_name = param_name[1:]
        param_name = param_name.replace("-", "_")

        if skip_names and param_name in skip_names:
            continue

        # gather other details
        default_value = action.default
        if param_name in overwrite_default_values:
            default_value = overwrite_default_values[param_name]
        arg_type = (
            action.type if not callable(action.type) else str(action.type.__name__)
        )
        description = action.help

        # determine if the argument is required
        required = action.required

        # extract valid values for choices, if provided
        valid_values = action.choices if action.choices is not None else None

        # set ext_metadata as an empty dict for now, can be updated later if needed
        ext_metadata = {}

        descriptions.append(
            ParameterDescription(
                param_class=param_class,
                param_name=param_name,
                param_type=arg_type,
                default_value=default_value,
                description=description,
                required=required,
                valid_values=valid_values,
                ext_metadata=ext_metadata,
            )
        )
    return descriptions


class _SimpleArgParser:
    def __init__(self, *args):
        self.params = {arg.replace("_", "-"): None for arg in args}
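A small usage sketch for `_extract_parameter_details` defined above; the argument names and the `param_class` label are invented for illustration, real callers pass their own. The new test file below exercises the same extraction paths.

```python
import argparse

from pilot.utils.parameter_utils import _extract_parameter_details

parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default=None, help="Path to model weights")
parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs to use")
parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")

# param_class is only a label copied into each ParameterDescription;
# "example.DemoParameters" is a placeholder, not a real class in the repo.
for desc in _extract_parameter_details(parser, param_class="example.DemoParameters"):
    print(desc.param_name, desc.param_type, desc.default_value, desc.required)
# model_path str None False
# num_gpus int 1 False
# load_8bit None False False
```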


@@ -0,0 +1,81 @@
import argparse

import pytest

from pilot.utils.parameter_utils import _extract_parameter_details


def create_parser():
    parser = argparse.ArgumentParser()
    return parser


@pytest.mark.parametrize(
    "argument, expected_param_name, default_value, param_type, expected_param_type, description",
    [
        ("--option", "option", "value", str, "str", "An option argument"),
        ("-option", "option", "value", str, "str", "An option argument"),
        ("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUs"),
        ("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUs"),
    ],
)
def test_extract_parameter_details_option_argument(
    argument,
    expected_param_name,
    default_value,
    param_type,
    expected_param_type,
    description,
):
    parser = create_parser()
    parser.add_argument(
        argument, default=default_value, type=param_type, help=description
    )

    descriptions = _extract_parameter_details(parser)

    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == expected_param_name
    assert desc.param_type == expected_param_type
    assert desc.default_value == default_value
    assert desc.description == description
    assert desc.required == False
    assert desc.valid_values is None


def test_extract_parameter_details_flag_argument():
    parser = create_parser()
    parser.add_argument("--flag", action="store_true", help="A flag argument")

    descriptions = _extract_parameter_details(parser)

    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "flag"
    assert desc.description == "A flag argument"
    assert desc.required == False


def test_extract_parameter_details_choice_argument():
    parser = create_parser()
    parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")

    descriptions = _extract_parameter_details(parser)

    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "choice"
    assert desc.valid_values == ["A", "B", "C"]


def test_extract_parameter_details_required_argument():
    parser = create_parser()
    parser.add_argument(
        "--required", required=True, type=int, help="A required argument"
    )

    descriptions = _extract_parameter_details(parser)

    assert len(descriptions) == 1
    desc = descriptions[0]
    assert desc.param_name == "required"
    assert desc.required == True


@@ -9,4 +9,6 @@ pytest-mock
pytest-recording
pytesseract==0.3.10
# python code format
black
# for git hooks
pre-commit


@@ -203,9 +203,9 @@ def get_cuda_version() -> str:
def torch_requires(
    torch_version: str = "2.0.1",  # was 2.0.0
    torchvision_version: str = "0.15.2",  # was 0.15.1
    torchaudio_version: str = "2.0.2",  # was 2.0.1
):
    torch_pkgs = [
        f"torch=={torch_version}",
@@ -298,6 +298,7 @@ def core_requires():
    ]
    setup_spec.extras["framework"] = [
        "fschat",
        "coloredlogs",
        "httpx",
        "sqlparse==0.4.4",
@@ -396,12 +397,19 @@ def gpt4all_requires():
    setup_spec.extras["gpt4all"] = ["gpt4all"]


def vllm_requires():
    """
    pip install "db-gpt[vllm]"
    """
    setup_spec.extras["vllm"] = ["vllm"]


def default_requires():
    """
    pip install "db-gpt[default]"
    """
    setup_spec.extras["default"] = [
        "tokenizers==0.13.3",  # was 0.13.2
        "accelerate>=0.20.3",
        "sentence-transformers",
        "protobuf==3.20.3",
@@ -435,6 +443,7 @@ all_vector_store_requires()
all_datasource_requires()
openai_requires()
gpt4all_requires()
vllm_requires()

# must be last
default_requires()