Merge branch 'main' into tt_dev

This commit is contained in:
csunny 2023-10-12 09:23:06 +08:00
commit eef2b381aa
39 changed files with 1001 additions and 197 deletions

.gitignore vendored
View File

@ -151,4 +151,6 @@ pilot/mock_datas/db-gpt-test.db.wal
logs
webserver.log.*
.history/*
.plugin_env
.plugin_env
# Ignore for now
thirdparty

View File

@ -10,11 +10,11 @@ git clone https://github.com/<YOUR-GITHUB-USERNAME>/DB-GPT
```
3. Install the project requirements
```
pip install -r requirements.txt
pip install -r requirements/dev-requirements.txt
```
4. Install pre-commit hooks
```
pre-commit install
pre-commit install --allow-missing-config
```
5. Create a new branch for your changes using the following command:

View File

@ -41,6 +41,7 @@ services:
restart: unless-stopped
networks:
- dbgptnet
ipc: host
deploy:
resources:
reservations:

View File

@ -1,6 +1,6 @@
#!/bin/bash
docker run --gpus all -d \
docker run --ipc host --gpus all -d \
-p 5000:5000 \
-e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PATH=data/default_sqlite.db \

View File

@ -4,7 +4,7 @@
PROXY_API_KEY="$PROXY_API_KEY"
PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}"
docker run --gpus all -d \
docker run --ipc host --gpus all -d \
-p 5000:5000 \
-e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PATH=data/default_sqlite.db \

View File

@ -21,6 +21,7 @@ services:
restart: unless-stopped
networks:
- dbgptnet
ipc: host
deploy:
resources:
reservations:

View File

@ -47,7 +47,7 @@ You can execute the command `bash docker/build_all_images.sh --help` to see more
**Run with local model and SQLite database**
```bash
docker run --gpus all -d \
docker run --ipc host --gpus all -d \
-p 5000:5000 \
-e LOCAL_DB_TYPE=sqlite \
-e LOCAL_DB_PATH=data/default_sqlite.db \
@ -73,7 +73,7 @@ docker logs dbgpt -f
**Run with local model and MySQL database**
```bash
docker run --gpus all -d -p 3306:3306 \
docker run --ipc host --gpus all -d -p 3306:3306 \
-p 5000:5000 \
-e LOCAL_DB_HOST=127.0.0.1 \
-e LOCAL_DB_PASSWORD=aa123456 \

View File

@ -30,3 +30,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin
./llama/llama_cpp.md
./quantization/quantization.md
./vllm/vllm.md

View File

@ -0,0 +1,26 @@
vLLM
==================================
[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving.
## Running vLLM
### Installing Dependencies
vLLM is an optional dependency in DB-GPT, and you can manually install it using the following command:
```bash
pip install -e ".[vllm]"
```
### Modifying the Configuration File
Next, you can directly modify your `.env` file to enable vllm.
```env
LLM_MODEL=vicuna-13b-v1.5
MODEL_TYPE=vllm
```
You can view the models supported by vLLM [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html#supported-models)
Then you can run it according to [Run](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run).
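For reference, the sketch below shows roughly what `MODEL_TYPE=vllm` maps to inside DB-GPT (based on the new `pilot/model/model_adapter.py` added in this commit). The model name and local path are placeholders, and this is an illustrative sketch rather than part of the official run steps.
```python
# Minimal sketch: resolving the vLLM adapter the same way the model worker does
# when MODEL_TYPE=vllm is configured. Model name/path below are placeholders.
from pilot.model.base import ModelType
from pilot.model.model_adapter import get_llm_model_adapter

adapter = get_llm_model_adapter(
    "vicuna-13b-v1.5",               # LLM_MODEL from .env
    "/data/models/vicuna-13b-v1.5",  # placeholder local model path
    model_type=ModelType.VLLM,
)
print(adapter.support_async())  # True: vLLM generation is served via an async stream
```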

View File

@ -0,0 +1,79 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.9\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-10-09 19:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"
#: ../../getting_started/install/llm/vllm/vllm.md:1
#: 9193438ba52148a3b71f190beeb4ef42
msgid "vLLM"
msgstr ""
#: ../../getting_started/install/llm/vllm/vllm.md:4
#: c30d032965794d7e81636581324be45d
msgid ""
"[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use "
"library for LLM inference and serving."
msgstr "[vLLM](https://github.com/vllm-project/vllm) 是一个快速且易于使用的 LLM 推理和服务的库。"
#: ../../getting_started/install/llm/vllm/vllm.md:6
#: b399c7268e0448cb893fbcb11a480849
msgid "Running vLLM"
msgstr "运行 vLLM"
#: ../../getting_started/install/llm/vllm/vllm.md:8
#: 7bed52b8bac946069a24df9e94098df5
msgid "Installing Dependencies"
msgstr "安装依赖"
#: ../../getting_started/install/llm/vllm/vllm.md:10
#: fd50a9f3e1b1459daa3b1a0cd610d1a3
msgid ""
"vLLM is an optional dependency in DB-GPT, and you can manually install it"
" using the following command:"
msgstr "vLLM 在 DB-GPT 是一个可选依赖, 你可以使用下面的命令手动安装它:"
#: ../../getting_started/install/llm/vllm/vllm.md:16
#: 44b251bc6b2c41ebaad9fd5a6a204c7c
msgid "Modifying the Configuration File"
msgstr "修改配置文件"
#: ../../getting_started/install/llm/vllm/vllm.md:18
#: 37f4e65148fa4339969265107b70b8fe
msgid "Next, you can directly modify your `.env` file to enable vllm."
msgstr "你可以直接修改你的 `.env` 文件。"
#: ../../getting_started/install/llm/vllm/vllm.md:24
#: 15d79c9417d04e779fa00a08a05e30d7
msgid ""
"You can view the models supported by vLLM "
"[here](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models)"
msgstr ""
"你可以在 "
"[这里](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models) 查看 vLLM 支持的模型。"
#: ../../getting_started/install/llm/vllm/vllm.md:26
#: 28d90b1fdf6943d9969c8668a7c1094b
msgid ""
"Then you can run it according to [Run](https://db-"
"gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run)."
msgstr ""
"然后你可以根据[运行]"
"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/getting_started/install/deploy/deploy.html#run)来启动项目。"

View File

@ -204,6 +204,9 @@ class Config(metaclass=Singleton):
self.SYSTEM_APP: Optional["SystemApp"] = None
### Temporary configuration
self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"
def set_debug_mode(self, value: bool) -> None:
"""Set the debug mode value"""
self.debug_mode = value

View File

@ -78,6 +78,8 @@ LLM_MODEL_CONFIG = {
"internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
"internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
"internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
# For test now
"opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}
EMBEDDING_MODEL_CONFIG = {

View File

@ -29,7 +29,7 @@ class SeparatorStyle(Enum):
@dataclasses.dataclass
class Conversation:
class OldConversation:
"""This class keeps all conversation history."""
system: str
@ -81,7 +81,7 @@ class Conversation:
return ret
def copy(self):
return Conversation(
return OldConversation(
system=self.system,
roles=self.roles,
messages=[[x, y] for x, y in self.messages],
@ -104,7 +104,7 @@ class Conversation:
}
conv_default = Conversation(
conv_default = OldConversation(
system=None,
roles=("human", "ai"),
messages=[],
@ -148,7 +148,7 @@ conv_default = Conversation(
# )
conv_one_shot = Conversation(
conv_one_shot = OldConversation(
system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. Here is an example:",
roles=("USER", "ASSISTANT"),
messages=(
@ -179,7 +179,7 @@ conv_one_shot = Conversation(
sep="###",
)
conv_vicuna_v1 = Conversation(
conv_vicuna_v1 = OldConversation(
system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
"The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
roles=("USER", "ASSISTANT"),
@ -190,7 +190,7 @@ conv_vicuna_v1 = Conversation(
sep2="</s>",
)
auto_dbgpt_one_shot = Conversation(
auto_dbgpt_one_shot = OldConversation(
system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),
@ -263,7 +263,7 @@ auto_dbgpt_one_shot = Conversation(
sep="###",
)
auto_dbgpt_without_shot = Conversation(
auto_dbgpt_without_shot = OldConversation(
system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
"Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
roles=("USER", "ASSISTANT"),

View File

@ -1,3 +1,7 @@
"""
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
@ -13,6 +17,8 @@ from transformers import (
AutoTokenizer,
LlamaTokenizer,
)
from pilot.model.base import ModelType
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
@ -26,15 +32,6 @@ logger = logging.getLogger(__name__)
CFG = Config()
class ModelType:
""" "Type of model"""
HF = "huggingface"
LLAMA_CPP = "llama.cpp"
PROXY = "proxy"
# TODO, support more model type
class BaseLLMAdaper:
"""The Base class for multi model, in our project.
We will support those model, which performance resemble ChatGPT"""
@ -95,33 +92,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
)
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
if model_name is None:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()
return [param_class]
def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
try:
llm_adapter = get_llm_model_adapter(model_name, model_path)

View File

@ -15,6 +15,16 @@ class Message(TypedDict):
content: str
class ModelType:
""" "Type of model"""
HF = "huggingface"
LLAMA_CPP = "llama.cpp"
PROXY = "proxy"
VLLM = "vllm"
# TODO, support more model type
@dataclass
class ModelInstance:
"""Model instance info"""

View File

@ -404,7 +404,7 @@ def stop_model_controller(port: int):
def _model_dynamic_factory() -> Callable[[None], List[Type]]:
from pilot.model.adapter import _dynamic_model_parser
from pilot.model.model_adapter import _dynamic_model_parser
param_class = _dynamic_model_parser()
fix_class = [ModelWorkerParameters]

View File

@ -1,26 +1,34 @@
import os
import logging
from typing import Dict, Iterator, List
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser
logger = logging.getLogger(__name__)
_torch_imported = False
try:
import torch
_torch_imported = True
except ImportError:
pass
class DefaultModelWorker(ModelWorker):
def __init__(self) -> None:
self.model = None
self.tokenizer = None
self._model_params = None
self.llm_adapter: BaseLLMAdaper = None
self.llm_chat_adapter: BaseChatAdpter = None
self.llm_adapter: LLMModelAdaper = None
self._support_async = False
def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
if model_path.endswith("/"):
@ -29,18 +37,24 @@ class DefaultModelWorker(ModelWorker):
self.model_name = model_name
self.model_path = model_path
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = kwargs.get("model_type")
### Temporary configuration, fastchat will be used by default in the future.
use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"
self.llm_adapter = get_llm_model_adapter(
self.model_name,
self.model_path,
use_fastchat=use_fastchat,
model_type=model_type,
)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)
self._support_async = self.llm_adapter.support_async()
logger.info(
f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
)
self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
self.model_path
)
self.ml: ModelLoader = ModelLoader(
model_path=self.model_path, model_name=self.model_name
)
@ -50,6 +64,9 @@ class DefaultModelWorker(ModelWorker):
def model_param_class(self) -> ModelParameters:
return self.param_cls
def support_async(self) -> bool:
return self._support_async
def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
param_cls = self.model_param_class()
model_args = EnvArgumentParser()
@ -77,7 +94,9 @@ class DefaultModelWorker(ModelWorker):
model_params = self.parse_parameters(command_args)
self._model_params = model_params
logger.info(f"Begin load model, model params: {model_params}")
self.model, self.tokenizer = self.ml.loader_with_params(model_params)
self.model, self.tokenizer = self.ml.loader_with_params(
model_params, self.llm_adapter
)
def stop(self) -> None:
if not self.model:
@ -90,51 +109,26 @@ class DefaultModelWorker(ModelWorker):
_clear_model_cache(self._model_params.device)
def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
torch_imported = False
try:
import torch
torch_imported = True
except ImportError:
pass
try:
# params adaptation
params, model_context = self.llm_chat_adapter.model_adaptation(
params, self.ml.model_path, prompt_template=self.ml.prompt_template
params, model_context, generate_stream_func = self._prepare_generate_stream(
params
)
previous_response = ""
print("stream output:\n")
for output in self.generate_stream_func(
for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), self.context_len
):
# Please do not open the output in production!
# The gpt4all thread shares stdout with the parent process,
# and opening it may affect the frontend output.
incremental_output = output[len(previous_response) :]
# print("output: ", output)
print(incremental_output, end="", flush=True)
previous_response = output
# return some model context to dgt-server
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
)
previous_response = output_str
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
except Exception as e:
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
)
else:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
yield model_output
yield self._handle_exception(e)
def generate(self, params: Dict) -> ModelOutput:
"""Generate non stream result"""
@ -145,3 +139,81 @@ class DefaultModelWorker(ModelWorker):
def embeddings(self, params: Dict) -> List[List[float]]:
raise NotImplementedError
async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
try:
params, model_context, generate_stream_func = self._prepare_generate_stream(
params
)
previous_response = ""
async for output in generate_stream_func(
self.model, self.tokenizer, params, get_device(), self.context_len
):
model_output, incremental_output, output_str = self._handle_output(
output, previous_response, model_context
)
previous_response = output_str
yield model_output
print(
f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
)
except Exception as e:
yield self._handle_exception(e)
async def async_generate(self, params: Dict) -> ModelOutput:
output = None
async for out in self.async_generate_stream(params):
output = out
return output
def _prepare_generate_stream(self, params: Dict):
params, model_context = self.llm_adapter.model_adaptation(
params,
self.model_name,
self.model_path,
prompt_template=self.ml.prompt_template,
)
stream_type = ""
if self.support_async():
generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
self.model, self.model_path
)
stream_type = "async "
logger.info(
"current generate stream function is asynchronous stream function"
)
else:
generate_stream_func = self.llm_adapter.get_generate_stream_function(
self.model, self.model_path
)
str_prompt = params.get("prompt")
print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
return params, model_context, generate_stream_func
def _handle_output(self, output, previous_response, model_context):
if isinstance(output, dict):
finish_reason = output.get("finish_reason")
output = output["text"]
if finish_reason is not None:
logger.info(f"finish_reason: {finish_reason}")
incremental_output = output[len(previous_response) :]
print(incremental_output, end="", flush=True)
model_output = ModelOutput(
text=output, error_code=0, model_context=model_context
)
return model_output, incremental_output, output
def _handle_exception(self, e):
# Check if the exception is a torch.cuda.CudaError and if torch was imported.
if _torch_imported and isinstance(e, torch.cuda.CudaError):
model_output = ModelOutput(
text="**GPU OutOfMemory, Please Refresh.**", error_code=0
)
else:
model_output = ModelOutput(
text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
error_code=0,
)
return model_output

View File

@ -3,7 +3,9 @@ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conver
Conversation prompt templates.
TODO Using fastchat core package
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
import dataclasses

View File

@ -1,6 +1,8 @@
"""
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

@ -7,7 +7,6 @@ import time
from typing import Optional
from pilot.configs.config import Config
from pilot.conversation import Conversation
# TODO Rewrite this
@ -44,36 +43,6 @@ def retry_stream_api(
return _wrapper
# Overly simple abstraction util we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
conv: Conversation,
model: Optional[str] = None,
temperature: float = None,
max_new_tokens: Optional[int] = None,
) -> str:
"""Create a chat completion using the Vicuna-13b
Args:
messages(List[Message]): The messages to send to the chat completion
model (str, optional): The model to use. Default to None.
temperature (float, optional): The temperature to use. Defaults to 0.7.
max_tokens (int, optional): The max tokens to use. Defaults to None.
Returns:
str: The response from the chat completion
"""
cfg = Config()
if temperature is None:
temperature = cfg.temperature
# TODO request vicuna model get response
# convert vicuna message to chat completion.
for plugin in cfg.plugins:
if plugin.can_handle_chat_completion():
pass
class ChatIO(abc.ABC):
@abc.abstractmethod
def prompt_for_input(self, role: str) -> str:

View File

@ -1,7 +1,6 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER
def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):

View File

@ -0,0 +1,56 @@
from typing import Dict
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams
async def generate_stream(
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
"""
Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
"""
prompt = params["prompt"]
request_id = params.pop("request_id") if "request_id" in params else random_uuid()
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
max_new_tokens = int(params.get("max_new_tokens", 2048))
echo = bool(params.get("echo", True))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
if tokenizer.eos_token_id is not None:
stop_token_ids.append(tokenizer.eos_token_id)
# Handle stop_str
stop = set()
if isinstance(stop_str, str) and stop_str != "":
stop.add(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.update(stop_str)
for tid in stop_token_ids:
if tid is not None:
stop.add(tokenizer.decode(tid))
# make sampling params in vllm
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
use_beam_search=False,
stop=list(stop),
max_tokens=max_new_tokens,
)
results_generator = model.generate(prompt, sampling_params, request_id)
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [prompt + output.text for output in request_output.outputs]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
yield {"text": text_outputs, "error_code": 0, "usage": {}}

View File

@ -146,7 +146,7 @@ def list_supported_models():
def _list_supported_models(
worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
from pilot.model.adapter import get_llm_model_adapter
from pilot.model.model_adapter import get_llm_model_adapter
from pilot.model.parameter import ModelParameters
from pilot.model.loader import _get_model_real_path

View File

@ -3,9 +3,11 @@
from typing import Optional, Dict
from dataclasses import asdict
import logging
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.base import ModelType
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
@ -114,8 +116,9 @@ class ModelLoader:
else:
raise Exception(f"Unkown model type {model_type}")
def loader_with_params(self, model_params: ModelParameters):
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
def loader_with_params(
self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
):
model_type = llm_adapter.model_type()
self.prompt_template = model_params.prompt_template
if model_type == ModelType.HF:
@ -124,11 +127,13 @@ class ModelLoader:
return llamacpp_loader(llm_adapter, model_params)
elif model_type == ModelType.PROXY:
return proxyllm_loader(llm_adapter, model_params)
elif model_type == ModelType.VLLM:
return llm_adapter.load_from_params(model_params)
else:
raise Exception(f"Unkown model type {model_type}")
def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
import torch
from pilot.model.compression import compress_module
@ -181,7 +186,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
f"Current model {model_params.model_name} not supported quantization"
)
# default loader
model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)
model, tokenizer = llm_adapter.load(model_params.model_path, kwargs)
if model_params.load_8bit and num_gpus == 1 and tokenizer:
# TODO merge current code into `load_huggingface_quantization_model`
@ -204,7 +209,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
def load_huggingface_quantization_model(
llm_adapter: BaseLLMAdaper,
llm_adapter: LLMModelAdaper,
model_params: ModelParameters,
kwargs: Dict,
max_memory: Dict[int, str],
@ -339,7 +344,7 @@ def load_huggingface_quantization_model(
return model, tokenizer
def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters):
def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
try:
from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel
except ImportError as exc:
@ -353,7 +358,7 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
return model, tokenizer
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
from pilot.model.proxy.llms.proxy_model import ProxyModel
logger.info("Load proxyllm")

View File

@ -0,0 +1,431 @@
from __future__ import annotations
from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
import dataclasses
import logging
import threading
import os
from functools import cache
from pilot.model.base import ModelType
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.utils.parameter_utils import (
_extract_parameter_details,
_build_parameter_class,
_get_dataclass_print_str,
)
try:
from fastchat.conversation import (
Conversation,
register_conv_template,
SeparatorStyle,
)
except ImportError as exc:
raise ValueError(
"Could not import python package: fschat "
"Please install fastchat by command `pip install fschat` "
) from exc
if TYPE_CHECKING:
from fastchat.model.model_adapter import BaseModelAdapter
from pilot.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
from torch.nn import Module as TorchNNModule
logger = logging.getLogger(__name__)
thread_local = threading.local()
_OLD_MODELS = [
"llama-cpp",
"proxyllm",
"gptj-6b",
]
class LLMModelAdaper:
"""New Adapter for DB-GPT LLM models"""
def model_type(self) -> str:
return ModelType.HF
def model_param_class(self, model_type: str = None) -> ModelParameters:
"""Get the startup parameters instance of the model"""
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def load(self, model_path: str, from_pretrained_kwargs: dict):
"""Load model and tokenizer"""
raise NotImplementedError
def load_from_params(self, params):
"""Load the model and tokenizer according to the given parameters"""
raise NotImplementedError
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""
return False
def get_generate_stream_function(self, model, model_path: str):
"""Get the generate stream function of the model"""
raise NotImplementedError
def get_async_generate_stream_function(self, model, model_path: str):
"""Get the asynchronous generate stream function of the model"""
raise NotImplementedError
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
"""Get the default conv template"""
raise NotImplementedError
def model_adaptation(
self,
params: Dict,
model_name: str,
model_path: str,
prompt_template: str = None,
) -> Tuple[Dict, Dict]:
"""Params adaptation"""
conv = self.get_default_conv_template(model_name, model_path)
messages = params.get("messages")
# Some model context is returned to the dbgpt server
model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}
if messages:
# Dict message to ModelMessage
messages = [
m if isinstance(m, ModelMessage) else ModelMessage(**m)
for m in messages
]
params["messages"] = messages
if prompt_template:
logger.info(f"Use prompt template {prompt_template} from config")
conv = get_conv_template(prompt_template)
if not conv or not messages:
# Nothing to do
logger.info(
f"No conv from model_path {model_path} or no messages in params, {self}"
)
return params, model_context
conv = conv.copy()
system_messages = []
for message in messages:
role, content = None, None
if isinstance(message, ModelMessage):
role = message.role
content = message.content
elif isinstance(message, dict):
role = message["role"]
content = message["content"]
else:
raise ValueError(f"Invalid message type: {message}")
if role == ModelMessageRoleType.SYSTEM:
# Support for multiple system messages
system_messages.append(content)
elif role == ModelMessageRoleType.HUMAN:
conv.append_message(conv.roles[0], content)
elif role == ModelMessageRoleType.AI:
conv.append_message(conv.roles[1], content)
else:
raise ValueError(f"Unknown role: {role}")
if system_messages:
conv.set_system_message("".join(system_messages))
# Add a blank message for the assistant.
conv.append_message(conv.roles[1], None)
new_prompt = conv.get_prompt()
# Overwrite the original prompt
# TODO remove bos token and eos token from tokenizer_config.json of model
prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
model_context["prompt_echo_len_char"] = prompt_echo_len_char
model_context["echo"] = params.get("echo", True)
model_context["has_format_prompt"] = True
params["prompt"] = new_prompt
# Overwrite model params:
params["stop"] = conv.stop_str
return params, model_context
class OldLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping old adapter, which may be removed later"""
def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
self._adapter = adapter
self._chat_adapter = chat_adapter
def model_type(self) -> str:
return self._adapter.model_type()
def model_param_class(self, model_type: str = None) -> ModelParameters:
return self._adapter.model_param_class(model_type)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._chat_adapter.get_conv_template(model_path)
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.loader(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model, model_path: str):
return self._chat_adapter.get_generate_stream_func(model_path)
class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping fastchat adapter"""
def __init__(self, adapter: "BaseModelAdapter") -> None:
self._adapter = adapter
def load(self, model_path: str, from_pretrained_kwargs: dict):
return self._adapter.load_model(model_path, from_pretrained_kwargs)
def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
from fastchat.model.model_adapter import get_generate_stream_function
return get_generate_stream_function(model, model_path)
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return self._adapter.get_default_conv_template(model_path)
def get_conv_template(name: str) -> "Conversation":
"""Get a conversation template."""
from fastchat.conversation import get_conv_template
return get_conv_template(name)
@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
try:
adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
return adapter.get_default_conv_template(model_name, model_path)
except Exception:
return None
@cache
def get_llm_model_adapter(
model_name: str,
model_path: str,
use_fastchat: bool = True,
use_fastchat_monkey_patch: bool = False,
model_type: str = None,
) -> LLMModelAdaper:
if model_type == ModelType.VLLM:
logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
return VLLMModelAdaperWrapper()
must_use_old = any(m in model_name for m in _OLD_MODELS)
if use_fastchat and not must_use_old:
logger.info("Use fastcat adapter")
adapter = _get_fastchat_model_adapter(
model_name,
model_path,
_fastchat_get_adapter_monkey_patch,
use_fastchat_monkey_patch=use_fastchat_monkey_patch,
)
return FastChatLLMModelAdaperWrapper(adapter)
else:
from pilot.model.adapter import (
get_llm_model_adapter as _old_get_llm_model_adapter,
)
from pilot.server.chat_adapter import get_llm_chat_adapter
logger.info("Use DB-GPT old adapter")
return OldLLMModelAdaperWrapper(
_old_get_llm_model_adapter(model_name, model_path),
get_llm_chat_adapter(model_name, model_path),
)
def _get_fastchat_model_adapter(
model_name: str,
model_path: str,
caller: Callable[[str], None] = None,
use_fastchat_monkey_patch: bool = False,
):
from fastchat.model import model_adapter
_bak_get_model_adapter = model_adapter.get_model_adapter
try:
if use_fastchat_monkey_patch:
model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
thread_local.model_name = model_name
if caller:
return caller(model_path)
finally:
del thread_local.model_name
model_adapter.get_model_adapter = _bak_get_model_adapter
def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
if not model_name:
if not hasattr(thread_local, "model_name"):
raise RuntimeError("fastchat get adapter monkey path need model_name")
model_name = thread_local.model_name
from fastchat.model.model_adapter import model_adapters
for adapter in model_adapters:
if adapter.match(model_name):
logger.info(
f"Found llm model adapter with model name: {model_name}, {adapter}"
)
return adapter
model_path_basename = (
None if not model_path else os.path.basename(os.path.normpath(model_path))
)
for adapter in model_adapters:
if model_path_basename and adapter.match(model_path_basename):
logger.info(
f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
)
return adapter
for adapter in model_adapters:
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
return adapter
raise ValueError(
f"Invalid model adapter for model name {model_name} and model path {model_path}"
)
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
from pilot.model.parameter import (
EmbeddingModelParameters,
WorkerType,
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
)
pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
worker_type = pre_args.get("worker_type")
model_type = pre_args.get("model_type")
if model_name is None and model_type != ModelType.VLLM:
return None
if worker_type == WorkerType.TEXT2VEC:
return [
EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
model_name, EmbeddingModelParameters
)
]
llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
param_class = llm_adapter.model_param_class()
return [param_class]
class VLLMModelAdaperWrapper(LLMModelAdaper):
"""Wrapping vllm engine"""
def model_type(self) -> str:
return ModelType.VLLM
def model_param_class(self, model_type: str = None) -> ModelParameters:
import argparse
from vllm.engine.arg_utils import AsyncEngineArgs
parser = argparse.ArgumentParser()
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument("--model_name", type=str, help="model name")
parser.add_argument(
"--model_path",
type=str,
help="local model path of the huggingface model to use",
)
parser.add_argument("--model_type", type=str, help="model type")
parser.add_argument("--device", type=str, default=None, help="device")
# TODO parse prompt template from `model_name` and `model_path`
parser.add_argument(
"--prompt_template",
type=str,
default=None,
help="Prompt template. If None, the prompt template is automatically determined from model path",
)
descs = _extract_parameter_details(
parser,
"pilot.model.parameter.VLLMModelParameters",
skip_names=["model"],
overwrite_default_values={"trust_remote_code": True},
)
return _build_parameter_class(descs)
def load_from_params(self, params):
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
import torch
num_gpus = torch.cuda.device_count()
if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
setattr(params, "tensor_parallel_size", num_gpus)
logger.info(
f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
)
params = dataclasses.asdict(params)
params["model"] = params["model_path"]
attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
# Set the attributes from the parsed arguments.
engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
engine = AsyncLLMEngine.from_engine_args(engine_args)
return engine, engine.engine.tokenizer
def support_async(self) -> bool:
return True
def get_async_generate_stream_function(self, model, model_path: str):
from pilot.model.llm_out.vllm_llm import generate_stream
return generate_stream
def get_default_conv_template(
self, model_name: str, model_path: str
) -> "Conversation":
return _auto_get_conv_template(model_name, model_path)
# Overriding the configuration of fastchat; we will regularly feed the code here back to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
register_conv_template(
Conversation(
name="internlm-chat",
system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
roles=("<|User|>", "<|Bot|>"),
sep_style=SeparatorStyle.CHATINTERN,
sep="<eoh>",
sep2="<eoa>",
stop_token_ids=[1, 103028],
# TODO feedback stop_str to fastchat
stop_str="<eoa>",
),
override=True,
)
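To make the new adapter API above concrete, here is a hedged end-to-end sketch of the lifecycle that `DefaultModelWorker` follows with it (simplified; the model name, path, load kwargs, and message content are placeholders, not values taken from this commit).
```python
# Sketch of the adapter lifecycle introduced in pilot/model/model_adapter.py.
# Mirrors what DefaultModelWorker does, with placeholder model name and path.
from pilot.model.model_adapter import get_llm_model_adapter

model_name, model_path = "vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5"

adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
param_cls = adapter.model_param_class()  # startup-parameter dataclass for this model type

# kwargs are forwarded to transformers from_pretrained (device_map needs accelerate).
model, tokenizer = adapter.load(model_path, {"device_map": "auto"})

# Convert chat messages into the model's prompt format via the fastchat conv template.
params = {"messages": [{"role": "human", "content": "Hello"}]}
params, model_context = adapter.model_adaptation(params, model_name, model_path)

# Pick the (sync or async) token stream function and generate from params["prompt"].
stream_func = adapter.get_generate_stream_function(model, model_path)
```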

View File

@ -64,6 +64,13 @@ class ModelWorkerParameters(BaseModelParameters):
default=None,
metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
)
model_type: Optional[str] = field(
default="huggingface",
metadata={
"help": "Model type: huggingface, llama.cpp, proxy and vllm",
"tags": "fixed",
},
)
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Model worker deploy host"}
)
@ -163,7 +170,7 @@ class ModelParameters(BaseModelParameters):
model_type: Optional[str] = field(
default="huggingface",
metadata={
"help": "Model type, huggingface, llama.cpp and proxy",
"help": "Model type: huggingface, llama.cpp, proxy and vllm",
"tags": "fixed",
},
)
@ -292,7 +299,7 @@ class ProxyModelParameters(BaseModelParameters):
model_type: Optional[str] = field(
default="proxy",
metadata={
"help": "Model type, huggingface, llama.cpp and proxy",
"help": "Model type: huggingface, llama.cpp, proxy and vllm",
"tags": "fixed",
},
)

View File

@ -140,7 +140,7 @@ class BaseChat(ABC):
payload = self.__call_base()
self.skip_echo_len = len(payload.get("prompt").replace("</s>", " ")) + 11
logger.info(f"Requert: \n{payload}")
logger.info(f"Request: \n{payload}")
ai_response_text = ""
try:
from pilot.model.cluster import WorkerManagerFactory

View File

@ -19,7 +19,7 @@ def signal_handler(sig, frame):
os._exit(0)
def async_db_summery(system_app: SystemApp):
def async_db_summary(system_app: SystemApp):
from pilot.summary.db_summary_client import DBSummaryClient
client = DBSummaryClient(system_app=system_app)
@ -79,7 +79,7 @@ def _create_model_start_listener(system_app: SystemApp):
print("begin run _add_app_startup_event")
conn_manage = ConnectManager(system_app)
cfg.LOCAL_DB_MANAGE = conn_manage
async_db_summery(system_app)
async_db_summary(system_app)
return startup_event

View File

@ -1,3 +1,7 @@
"""
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

View File

View File

@ -1,18 +1,18 @@
class DBSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.summary = None
self.tables = []
self.metadata = str
def get_summery(self):
return self.summery
def get_summary(self):
return self.summary
class TableSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.summary = None
self.fields = []
self.indexes = []
@ -20,12 +20,12 @@ class TableSummary:
class FieldSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.summary = None
self.data_type = None
class IndexSummary:
def __init__(self, name):
self.name = name
self.summery = None
self.summary = None
self.bind_fields = []

View File

@ -47,13 +47,13 @@ class DBSummaryClient:
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_summery(),
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
self.init_db_profile(db_summary_client, dbname, embeddings)
if not embedding.vector_name_exist():
if CFG.SUMMARY_CONFIG == "FAST":
for vector_table_info in db_summary_client.get_summery():
for vector_table_info in db_summary_client.get_summary():
embedding = StringEmbedding(
vector_table_info,
vector_store_config,
@ -61,7 +61,7 @@ class DBSummaryClient:
embedding.source_embedding()
else:
embedding = StringEmbedding(
file_path=db_summary_client.get_summery(),
file_path=db_summary_client.get_summary(),
vector_store_config=vector_store_config,
)
embedding.source_embedding()
@ -144,8 +144,8 @@ class DBSummaryClient:
vector_store_config=vector_store_config,
embedding_factory=embedding_factory,
)
table_summery = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summery[0].page_content)
table_summary = knowledge_embedding_client.similar_search(query, 1)
related_table_summaries.append(table_summary[0].page_content)
return related_table_summaries
def init_db_summary(self):
@ -169,7 +169,7 @@ class DBSummaryClient:
"embeddings": embeddings,
}
embedding = StringEmbedding(
file_path=db_summary_client.get_db_summery(),
file_path=db_summary_client.get_db_summary(),
vector_store_config=profile_store_config,
)
if not embedding.vector_name_exist():

View File

@ -12,7 +12,7 @@ class RdbmsSummary(DBSummary):
def __init__(self, name, type):
self.name = name
self.type = type
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
self.tables = {}
self.tables_info = []
self.vector_tables_info = []
@ -48,7 +48,7 @@ class RdbmsSummary(DBSummary):
for table_name in tables:
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
# self.tables[table_name] = table_summary.get_summery()
# self.tables[table_name] = table_summary.get_summary()
self.tables[table_name] = table_summary.get_columns()
self.table_columns_info.append(table_summary.get_columns())
# self.table_columns_json.append(table_summary.get_summary_json())
@ -59,18 +59,18 @@ class RdbmsSummary(DBSummary):
)
)
self.table_columns_json.append(table_profile)
# self.tables_info.append(table_summary.get_summery())
# self.tables_info.append(table_summary.get_summary())
def get_summery(self):
def get_summary(self):
if CFG.SUMMARY_CONFIG == "FAST":
return self.vector_tables_info
else:
return self.summery.format(
return self.summary.format(
name=self.name, type=self.type, table_info=";".join(self.tables_info)
)
def get_db_summery(self):
return self.summery.format(
def get_db_summary(self):
return self.summary.format(
name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
@ -94,8 +94,8 @@ class RdbmsTableSummary(TableSummary):
def __init__(self, instance, dbname, name, comment_map):
self.name = name
self.dbname = dbname
self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = []
self.fields_info = []
self.indexes = []
@ -107,19 +107,19 @@ class RdbmsTableSummary(TableSummary):
for field in fields:
field_summary = RdbmsFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summery())
self.fields_info.append(field_summary.get_summary())
field_names.append(field[0])
self.column_summery = """{name}({columns_info})""".format(
self.column_summary = """{name}({columns_info})""".format(
name=name, columns_info=",".join(field_names)
)
for index in indexes:
index_summary = RdbmsIndexSummary(index)
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summery())
self.indexes_info.append(index_summary.get_summary())
self.json_summery = self.json_summery_template.format(
self.json_summary = self.json_summary_template.format(
name=name,
comment=comment_map[name],
fields=self.fields_info,
@ -128,8 +128,8 @@ class RdbmsTableSummary(TableSummary):
rows=1000,
)
def get_summery(self):
return self.summery.format(
def get_summary(self):
return self.summary.format(
name=self.name,
dbname=self.dbname,
fields=";".join(self.fields_info),
@ -137,10 +137,10 @@ class RdbmsTableSummary(TableSummary):
)
def get_columns(self):
return self.column_summery
return self.column_summary
def get_summary_json(self):
return self.json_summery
return self.json_summary
class RdbmsFieldsSummary(FieldSummary):
@ -148,14 +148,14 @@ class RdbmsFieldsSummary(FieldSummary):
def __init__(self, field):
self.name = field[0]
# self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
# self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]
def get_summery(self):
def get_summary(self):
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name,
data_type=self.data_type,
@ -170,11 +170,11 @@ class RdbmsIndexSummary(IndexSummary):
def __init__(self, index):
self.name = index[0]
# self.summery = """index name:{name}, index bind columns:{bind_fields}"""
self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
# self.summary = """index name:{name}, index bind columns:{bind_fields}"""
self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1]
def get_summery(self):
return self.summery_template.format(
def get_summary(self):
return self.summary_template.format(
name=self.name, bind_fields=self.bind_fields
)

View File

@ -86,17 +86,7 @@ class BaseParameters:
return updated
def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = _get_simple_privacy_field_value(self, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
return _get_dataclass_print_str(self)
def to_command_args(self, args_prefix: str = "--") -> List[str]:
"""Convert the fields of the dataclass to a list of command line arguments.
@ -110,6 +100,20 @@ class BaseParameters:
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)
def _get_dataclass_print_str(obj):
class_name = obj.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(obj):
value = _get_simple_privacy_field_value(obj, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
"""Convert dict to a list of command line arguments
@ -493,9 +497,10 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
if not desc:
raise ValueError("Parameter descriptions cant be empty")
param_class_str = desc[0].param_class
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
if param_class_str:
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
module_name, _, class_name = param_class_str.rpartition(".")
fields_dict = {} # This will store field names and their default values or field()
@ -520,6 +525,71 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
return result_class
def _extract_parameter_details(
parser: argparse.ArgumentParser,
param_class: str = None,
skip_names: List[str] = None,
overwrite_default_values: Dict = {},
) -> List[ParameterDescription]:
descriptions = []
for action in parser._actions:
if (
action.default == argparse.SUPPRESS
): # typically this means the argument was not provided
continue
# determine parameter class (store_true/store_false are flags)
flag_or_option = (
"flag" if isinstance(action, argparse._StoreConstAction) else "option"
)
# extract parameter name (use the first option string, typically the long form)
param_name = action.option_strings[0] if action.option_strings else action.dest
if param_name.startswith("--"):
param_name = param_name[2:]
if param_name.startswith("-"):
param_name = param_name[1:]
param_name = param_name.replace("-", "_")
if skip_names and param_name in skip_names:
continue
# gather other details
default_value = action.default
if param_name in overwrite_default_values:
default_value = overwrite_default_values[param_name]
arg_type = (
action.type if not callable(action.type) else str(action.type.__name__)
)
description = action.help
# determine if the argument is required
required = action.required
# extract valid values for choices, if provided
valid_values = action.choices if action.choices is not None else None
# set ext_metadata as an empty dict for now, can be updated later if needed
ext_metadata = {}
descriptions.append(
ParameterDescription(
param_class=param_class,
param_name=param_name,
param_type=arg_type,
default_value=default_value,
description=description,
required=required,
valid_values=valid_values,
ext_metadata=ext_metadata,
)
)
return descriptions
class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}

View File

View File

@ -0,0 +1,81 @@
import argparse
import pytest
from pilot.utils.parameter_utils import _extract_parameter_details
def create_parser():
parser = argparse.ArgumentParser()
return parser
@pytest.mark.parametrize(
"argument, expected_param_name, default_value, param_type, expected_param_type, description",
[
("--option", "option", "value", str, "str", "An option argument"),
("-option", "option", "value", str, "str", "An option argument"),
("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
],
)
def test_extract_parameter_details_option_argument(
argument,
expected_param_name,
default_value,
param_type,
expected_param_type,
description,
):
parser = create_parser()
parser.add_argument(
argument, default=default_value, type=param_type, help=description
)
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == expected_param_name
assert desc.param_type == expected_param_type
assert desc.default_value == default_value
assert desc.description == description
assert desc.required == False
assert desc.valid_values is None
def test_extract_parameter_details_flag_argument():
parser = create_parser()
parser.add_argument("--flag", action="store_true", help="A flag argument")
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "flag"
assert desc.description == "A flag argument"
assert desc.required == False
def test_extract_parameter_details_choice_argument():
parser = create_parser()
parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "choice"
assert desc.valid_values == ["A", "B", "C"]
def test_extract_parameter_details_required_argument():
parser = create_parser()
parser.add_argument(
"--required", required=True, type=int, help="A required argument"
)
descriptions = _extract_parameter_details(parser)
assert len(descriptions) == 1
desc = descriptions[0]
assert desc.param_name == "required"
assert desc.required == True

View File

@ -9,4 +9,6 @@ pytest-mock
pytest-recording
pytesseract==0.3.10
# python code format
black
black
# for git hooks
pre-commit

View File

@ -203,9 +203,9 @@ def get_cuda_version() -> str:
def torch_requires(
torch_version: str = "2.0.0",
torchvision_version: str = "0.15.1",
torchaudio_version: str = "2.0.1",
torch_version: str = "2.0.1",
torchvision_version: str = "0.15.2",
torchaudio_version: str = "2.0.2",
):
torch_pkgs = [
f"torch=={torch_version}",
@ -298,6 +298,7 @@ def core_requires():
]
setup_spec.extras["framework"] = [
"fschat",
"coloredlogs",
"httpx",
"sqlparse==0.4.4",
@ -396,12 +397,19 @@ def gpt4all_requires():
setup_spec.extras["gpt4all"] = ["gpt4all"]
def vllm_requires():
"""
pip install "db-gpt[vllm]"
"""
setup_spec.extras["vllm"] = ["vllm"]
def default_requires():
"""
pip install "db-gpt[default]"
"""
setup_spec.extras["default"] = [
"tokenizers==0.13.2",
"tokenizers==0.13.3",
"accelerate>=0.20.3",
"sentence-transformers",
"protobuf==3.20.3",
@ -435,6 +443,7 @@ all_vector_store_requires()
all_datasource_requires()
openai_requires()
gpt4all_requires()
vllm_requires()
# must be last
default_requires()