diff --git a/.gitignore b/.gitignore index d9fbfd59b..cbebc0334 100644 --- a/.gitignore +++ b/.gitignore @@ -151,4 +151,6 @@ pilot/mock_datas/db-gpt-test.db.wal logswebserver.log.* .history/* -.plugin_env \ No newline at end of file +.plugin_env +# Ignore for now +thirdparty \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1f44571ac..172067442 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,11 +10,11 @@ git clone https://github.com//DB-GPT ``` 3. Install the project requirements ``` -pip install -r requirements.txt +pip install -r requirements/dev-requirements.txt ``` 4. Install pre-commit hooks ``` -pre-commit install +pre-commit install --allow-missing-config ``` 5. Create a new branch for your changes using the following command: diff --git a/docker-compose.yml b/docker-compose.yml index 24856bff2..38d9915e0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -41,6 +41,7 @@ services: restart: unless-stopped networks: - dbgptnet + ipc: host deploy: resources: reservations: diff --git a/docker/base/run_sqlite.sh b/docker/base/run_sqlite.sh index ed06c2808..3915ac30c 100755 --- a/docker/base/run_sqlite.sh +++ b/docker/base/run_sqlite.sh @@ -1,6 +1,6 @@ #!/bin/bash -docker run --gpus all -d \ +docker run --ipc host --gpus all -d \ -p 5000:5000 \ -e LOCAL_DB_TYPE=sqlite \ -e LOCAL_DB_PATH=data/default_sqlite.db \ diff --git a/docker/base/run_sqlite_proxyllm.sh b/docker/base/run_sqlite_proxyllm.sh index a707953cc..45d4730b2 100755 --- a/docker/base/run_sqlite_proxyllm.sh +++ b/docker/base/run_sqlite_proxyllm.sh @@ -4,7 +4,7 @@ PROXY_API_KEY="$PROXY_API_KEY" PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}" -docker run --gpus all -d \ +docker run --ipc host --gpus all -d \ -p 5000:5000 \ -e LOCAL_DB_TYPE=sqlite \ -e LOCAL_DB_PATH=data/default_sqlite.db \ diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml index df2faed91..b41033458 100644 --- a/docker/compose_examples/cluster-docker-compose.yml +++ b/docker/compose_examples/cluster-docker-compose.yml @@ -21,6 +21,7 @@ services: restart: unless-stopped networks: - dbgptnet + ipc: host deploy: resources: reservations: diff --git a/docs/getting_started/install/docker/docker.md b/docs/getting_started/install/docker/docker.md index 377ad0297..f0ba2f331 100644 --- a/docs/getting_started/install/docker/docker.md +++ b/docs/getting_started/install/docker/docker.md @@ -47,7 +47,7 @@ You can execute the command `bash docker/build_all_images.sh --help` to see more **Run with local model and SQLite database** ```bash -docker run --gpus all -d \ +docker run --ipc host --gpus all -d \ -p 5000:5000 \ -e LOCAL_DB_TYPE=sqlite \ -e LOCAL_DB_PATH=data/default_sqlite.db \ @@ -73,7 +73,7 @@ docker logs dbgpt -f **Run with local model and MySQL database** ```bash -docker run --gpus all -d -p 3306:3306 \ +docker run --ipc host --gpus all -d -p 3306:3306 \ -p 5000:5000 \ -e LOCAL_DB_HOST=127.0.0.1 \ -e LOCAL_DB_PASSWORD=aa123456 \ diff --git a/docs/getting_started/install/llm/llm.rst b/docs/getting_started/install/llm/llm.rst index c20d19c89..185fd4b4f 100644 --- a/docs/getting_started/install/llm/llm.rst +++ b/docs/getting_started/install/llm/llm.rst @@ -30,3 +30,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin ./llama/llama_cpp.md ./quantization/quantization.md + ./vllm/vllm.md diff --git a/docs/getting_started/install/llm/vllm/vllm.md b/docs/getting_started/install/llm/vllm/vllm.md 
new file mode 100644 index 000000000..c4d0d2137 --- /dev/null +++ b/docs/getting_started/install/llm/vllm/vllm.md @@ -0,0 +1,26 @@ +vLLM +================================== + +[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving. + +## Running vLLM + +### Installing Dependencies + +vLLM is an optional dependency in DB-GPT, and you can manually install it using the following command: + +```bash +pip install -e ".[vllm]" +``` + +### Modifying the Configuration File + +Next, you can directly modify your `.env` file to enable vllm. + +```env +LLM_MODEL=vicuna-13b-v1.5 +MODEL_TYPE=vllm +``` +You can view the models supported by vLLM [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html#supported-models) + +Then you can run it according to [Run](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run). \ No newline at end of file diff --git a/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/vllm/vllm.po b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/vllm/vllm.po new file mode 100644 index 000000000..66d9ddb71 --- /dev/null +++ b/docs/locales/zh_CN/LC_MESSAGES/getting_started/install/llm/vllm/vllm.po @@ -0,0 +1,79 @@ +# SOME DESCRIPTIVE TITLE. +# Copyright (C) 2023, csunny +# This file is distributed under the same license as the DB-GPT package. +# FIRST AUTHOR , 2023. +# +#, fuzzy +msgid "" +msgstr "" +"Project-Id-Version: DB-GPT 👏👏 0.3.9\n" +"Report-Msgid-Bugs-To: \n" +"POT-Creation-Date: 2023-10-09 19:46+0800\n" +"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" +"Last-Translator: FULL NAME \n" +"Language: zh_CN\n" +"Language-Team: zh_CN \n" +"Plural-Forms: nplurals=1; plural=0;\n" +"MIME-Version: 1.0\n" +"Content-Type: text/plain; charset=utf-8\n" +"Content-Transfer-Encoding: 8bit\n" +"Generated-By: Babel 2.12.1\n" + +#: ../../getting_started/install/llm/vllm/vllm.md:1 +#: 9193438ba52148a3b71f190beeb4ef42 +msgid "vLLM" +msgstr "" + +#: ../../getting_started/install/llm/vllm/vllm.md:4 +#: c30d032965794d7e81636581324be45d +msgid "" +"[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use " +"library for LLM inference and serving." +msgstr "[vLLM](https://github.com/vllm-project/vllm) 是一个快速且易于使用的 LLM 推理和服务的库。" + +#: ../../getting_started/install/llm/vllm/vllm.md:6 +#: b399c7268e0448cb893fbcb11a480849 +msgid "Running vLLM" +msgstr "运行 vLLM" + +#: ../../getting_started/install/llm/vllm/vllm.md:8 +#: 7bed52b8bac946069a24df9e94098df5 +msgid "Installing Dependencies" +msgstr "安装依赖" + +#: ../../getting_started/install/llm/vllm/vllm.md:10 +#: fd50a9f3e1b1459daa3b1a0cd610d1a3 +msgid "" +"vLLM is an optional dependency in DB-GPT, and you can manually install it" +" using the following command:" +msgstr "vLLM 在 DB-GPT 是一个可选依赖, 你可以使用下面的命令手动安装它:" + +#: ../../getting_started/install/llm/vllm/vllm.md:16 +#: 44b251bc6b2c41ebaad9fd5a6a204c7c +msgid "Modifying the Configuration File" +msgstr "修改配置文件" + +#: ../../getting_started/install/llm/vllm/vllm.md:18 +#: 37f4e65148fa4339969265107b70b8fe +msgid "Next, you can directly modify your `.env` file to enable vllm." 
+msgstr "你可以直接修改你的 `.env` 文件。" + +#: ../../getting_started/install/llm/vllm/vllm.md:24 +#: 15d79c9417d04e779fa00a08a05e30d7 +msgid "" +"You can view the models supported by vLLM " +"[here](https://vllm.readthedocs.io/en/latest/models/supported_models.html" +"#supported-models)" +msgstr "" +"你可以在 " +"[这里](https://vllm.readthedocs.io/en/latest/models/supported_models.html" +"#supported-models) 查看 vLLM 支持的模型。" + +#: ../../getting_started/install/llm/vllm/vllm.md:26 +#: 28d90b1fdf6943d9969c8668a7c1094b +msgid "" +"Then you can run it according to [Run](https://db-" +"gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run)." +msgstr "" +"然后你可以根据[运行]" +"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/getting_started/install/deploy/deploy.html#run)来启动项目。" diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 134bd66f6..d7912923f 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -201,6 +201,9 @@ class Config(metaclass=Singleton): self.SYSTEM_APP: Optional["SystemApp"] = None + ### Temporary configuration + self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true" + def set_debug_mode(self, value: bool) -> None: """Set the debug mode value""" self.debug_mode = value diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 927272cb1..6408aecd6 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -76,6 +76,8 @@ LLM_MODEL_CONFIG = { "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"), "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"), "internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"), + # For test now + "opt-125m": os.path.join(MODEL_PATH, "opt-125m"), } EMBEDDING_MODEL_CONFIG = { diff --git a/pilot/conversation.py b/pilot/conversation.py index aa31a3b23..814943a49 100644 --- a/pilot/conversation.py +++ b/pilot/conversation.py @@ -29,7 +29,7 @@ class SeparatorStyle(Enum): @dataclasses.dataclass -class Conversation: +class OldConversation: """This class keeps all conversation history.""" system: str @@ -81,7 +81,7 @@ class Conversation: return ret def copy(self): - return Conversation( + return OldConversation( system=self.system, roles=self.roles, messages=[[x, y] for x, y in self.messages], @@ -104,7 +104,7 @@ class Conversation: } -conv_default = Conversation( +conv_default = OldConversation( system=None, roles=("human", "ai"), messages=[], @@ -148,7 +148,7 @@ conv_default = Conversation( # ) -conv_one_shot = Conversation( +conv_one_shot = OldConversation( system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. Here is an example:", roles=("USER", "ASSISTANT"), messages=( @@ -179,7 +179,7 @@ conv_one_shot = Conversation( sep="###", ) -conv_vicuna_v1 = Conversation( +conv_vicuna_v1 = OldConversation( system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. " "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ", roles=("USER", "ASSISTANT"), @@ -190,7 +190,7 @@ conv_vicuna_v1 = Conversation( sep2="", ) -auto_dbgpt_one_shot = Conversation( +auto_dbgpt_one_shot = OldConversation( system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. 
" "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", roles=("USER", "ASSISTANT"), @@ -263,7 +263,7 @@ auto_dbgpt_one_shot = Conversation( sep="###", ) -auto_dbgpt_without_shot = Conversation( +auto_dbgpt_without_shot = OldConversation( system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. " "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.", roles=("USER", "ASSISTANT"), diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index 0ed42abc9..763d27059 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -1,3 +1,7 @@ +""" +This code file will be deprecated in the future. +We have integrated fastchat. For details, see: pilot/model/model_adapter.py +""" #!/usr/bin/env python3 # -*- coding: utf-8 -*- @@ -13,6 +17,8 @@ from transformers import ( AutoTokenizer, LlamaTokenizer, ) +from pilot.model.base import ModelType + from pilot.model.parameter import ( ModelParameters, LlamaCppModelParameters, @@ -26,15 +32,6 @@ logger = logging.getLogger(__name__) CFG = Config() -class ModelType: - """ "Type of model""" - - HF = "huggingface" - LLAMA_CPP = "llama.cpp" - PROXY = "proxy" - # TODO, support more model type - - class BaseLLMAdaper: """The Base class for multi model, in our project. We will support those model, which performance resemble ChatGPT""" @@ -95,33 +92,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: ) -def _dynamic_model_parser() -> Callable[[None], List[Type]]: - from pilot.utils.parameter_utils import _SimpleArgParser - from pilot.model.parameter import ( - EmbeddingModelParameters, - WorkerType, - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, - ) - - pre_args = _SimpleArgParser("model_name", "model_path", "worker_type") - pre_args.parse() - model_name = pre_args.get("model_name") - model_path = pre_args.get("model_path") - worker_type = pre_args.get("worker_type") - if model_name is None: - return None - if worker_type == WorkerType.TEXT2VEC: - return [ - EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( - model_name, EmbeddingModelParameters - ) - ] - - llm_adapter = get_llm_model_adapter(model_name, model_path) - param_class = llm_adapter.model_param_class() - return [param_class] - - def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters: try: llm_adapter = get_llm_model_adapter(model_name, model_path) diff --git a/pilot/model/base.py b/pilot/model/base.py index 1d46b3161..697253f05 100644 --- a/pilot/model/base.py +++ b/pilot/model/base.py @@ -15,6 +15,16 @@ class Message(TypedDict): content: str +class ModelType: + """ "Type of model""" + + HF = "huggingface" + LLAMA_CPP = "llama.cpp" + PROXY = "proxy" + VLLM = "vllm" + # TODO, support more model type + + @dataclass class ModelInstance: """Model instance info""" diff --git a/pilot/model/cli.py b/pilot/model/cli.py index 3e94c7045..1030adfc2 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -404,7 +404,7 @@ def stop_model_controller(port: int): def _model_dynamic_factory() -> Callable[[None], List[Type]]: - from pilot.model.adapter import _dynamic_model_parser + from pilot.model.model_adapter import _dynamic_model_parser param_class = _dynamic_model_parser() fix_class = [ModelWorkerParameters] diff --git a/pilot/model/cluster/worker/default_worker.py 
b/pilot/model/cluster/worker/default_worker.py index c210fcb44..9ccf18b52 100644 --- a/pilot/model/cluster/worker/default_worker.py +++ b/pilot/model/cluster/worker/default_worker.py @@ -1,26 +1,34 @@ +import os import logging from typing import Dict, Iterator, List from pilot.configs.model_config import get_device -from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper +from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper from pilot.model.base import ModelOutput from pilot.model.loader import ModelLoader, _get_model_real_path from pilot.model.parameter import ModelParameters from pilot.model.cluster.worker_base import ModelWorker -from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter from pilot.utils.model_utils import _clear_model_cache from pilot.utils.parameter_utils import EnvArgumentParser logger = logging.getLogger(__name__) +_torch_imported = False +try: + import torch + + _torch_imported = True +except ImportError: + pass + class DefaultModelWorker(ModelWorker): def __init__(self) -> None: self.model = None self.tokenizer = None self._model_params = None - self.llm_adapter: BaseLLMAdaper = None - self.llm_chat_adapter: BaseChatAdpter = None + self.llm_adapter: LLMModelAdaper = None + self._support_async = False def load_worker(self, model_name: str, model_path: str, **kwargs) -> None: if model_path.endswith("/"): @@ -29,18 +37,24 @@ class DefaultModelWorker(ModelWorker): self.model_name = model_name self.model_path = model_path - self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + model_type = kwargs.get("model_type") + ### Temporary configuration, fastchat will be used by default in the future. + use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true" + + self.llm_adapter = get_llm_model_adapter( + self.model_name, + self.model_path, + use_fastchat=use_fastchat, + model_type=model_type, + ) model_type = self.llm_adapter.model_type() self.param_cls = self.llm_adapter.model_param_class(model_type) + self._support_async = self.llm_adapter.support_async() + logger.info( f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}" ) - self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path) - self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( - self.model_path - ) - self.ml: ModelLoader = ModelLoader( model_path=self.model_path, model_name=self.model_name ) @@ -50,6 +64,9 @@ class DefaultModelWorker(ModelWorker): def model_param_class(self) -> ModelParameters: return self.param_cls + def support_async(self) -> bool: + return self._support_async + def parse_parameters(self, command_args: List[str] = None) -> ModelParameters: param_cls = self.model_param_class() model_args = EnvArgumentParser() @@ -77,7 +94,9 @@ class DefaultModelWorker(ModelWorker): model_params = self.parse_parameters(command_args) self._model_params = model_params logger.info(f"Begin load model, model params: {model_params}") - self.model, self.tokenizer = self.ml.loader_with_params(model_params) + self.model, self.tokenizer = self.ml.loader_with_params( + model_params, self.llm_adapter + ) def stop(self) -> None: if not self.model: @@ -90,51 +109,26 @@ class DefaultModelWorker(ModelWorker): _clear_model_cache(self._model_params.device) def generate_stream(self, params: Dict) -> Iterator[ModelOutput]: - torch_imported = False try: - import torch - - torch_imported = True - except ImportError: - pass - try: - # params adaptation - params, 
model_context = self.llm_chat_adapter.model_adaptation( - params, self.ml.model_path, prompt_template=self.ml.prompt_template + params, model_context, generate_stream_func = self._prepare_generate_stream( + params ) previous_response = "" - print("stream output:\n") - for output in self.generate_stream_func( + + for output in generate_stream_func( self.model, self.tokenizer, params, get_device(), self.context_len ): - # Please do not open the output in production! - # The gpt4all thread shares stdout with the parent process, - # and opening it may affect the frontend output. - incremental_output = output[len(previous_response) :] - # print("output: ", output) - print(incremental_output, end="", flush=True) - previous_response = output - # return some model context to dgt-server - model_output = ModelOutput( - text=output, error_code=0, model_context=model_context + model_output, incremental_output, output_str = self._handle_output( + output, previous_response, model_context ) + previous_response = output_str yield model_output print( f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}" ) except Exception as e: - # Check if the exception is a torch.cuda.CudaError and if torch was imported. - if torch_imported and isinstance(e, torch.cuda.CudaError): - model_output = ModelOutput( - text="**GPU OutOfMemory, Please Refresh.**", error_code=0 - ) - else: - model_output = ModelOutput( - text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", - error_code=0, - ) - yield model_output + yield self._handle_exception(e) def generate(self, params: Dict) -> ModelOutput: """Generate non stream result""" @@ -145,3 +139,81 @@ class DefaultModelWorker(ModelWorker): def embeddings(self, params: Dict) -> List[List[float]]: raise NotImplementedError + + async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]: + try: + params, model_context, generate_stream_func = self._prepare_generate_stream( + params + ) + + previous_response = "" + + async for output in generate_stream_func( + self.model, self.tokenizer, params, get_device(), self.context_len + ): + model_output, incremental_output, output_str = self._handle_output( + output, previous_response, model_context + ) + previous_response = output_str + yield model_output + print( + f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}" + ) + except Exception as e: + yield self._handle_exception(e) + + async def async_generate(self, params: Dict) -> ModelOutput: + output = None + async for out in self.async_generate_stream(params): + output = out + return output + + def _prepare_generate_stream(self, params: Dict): + params, model_context = self.llm_adapter.model_adaptation( + params, + self.model_name, + self.model_path, + prompt_template=self.ml.prompt_template, + ) + stream_type = "" + if self.support_async(): + generate_stream_func = self.llm_adapter.get_async_generate_stream_function( + self.model, self.model_path + ) + stream_type = "async " + logger.info( + "current generate stream function is asynchronous stream function" + ) + else: + generate_stream_func = self.llm_adapter.get_generate_stream_function( + self.model, self.model_path + ) + str_prompt = params.get("prompt") + print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n") + return params, model_context, generate_stream_func + + def _handle_output(self, output, previous_response, model_context): + if isinstance(output, dict): + finish_reason = output.get("finish_reason") + 
output = output["text"] + if finish_reason is not None: + logger.info(f"finish_reason: {finish_reason}") + incremental_output = output[len(previous_response) :] + print(incremental_output, end="", flush=True) + model_output = ModelOutput( + text=output, error_code=0, model_context=model_context + ) + return model_output, incremental_output, output + + def _handle_exception(self, e): + # Check if the exception is a torch.cuda.CudaError and if torch was imported. + if _torch_imported and isinstance(e, torch.cuda.CudaError): + model_output = ModelOutput( + text="**GPU OutOfMemory, Please Refresh.**", error_code=0 + ) + else: + model_output = ModelOutput( + text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}", + error_code=0, + ) + return model_output diff --git a/pilot/model/conversation.py b/pilot/model/conversation.py index c8cb2a74b..b3674e946 100644 --- a/pilot/model/conversation.py +++ b/pilot/model/conversation.py @@ -3,7 +3,9 @@ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conver Conversation prompt templates. -TODO Using fastchat core package + +This code file will be deprecated in the future. +We have integrated fastchat. For details, see: pilot/model/model_adapter.py """ import dataclasses diff --git a/pilot/model/inference.py b/pilot/model/inference.py index 1cf681ebe..9c99c8ad0 100644 --- a/pilot/model/inference.py +++ b/pilot/model/inference.py @@ -1,6 +1,8 @@ """ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py +This code file will be deprecated in the future. +We have integrated fastchat. For details, see: pilot/model/model_adapter.py """ #!/usr/bin/env python3 # -*- coding: utf-8 -*- diff --git a/pilot/model/llm/llm_utils.py b/pilot/model/llm/llm_utils.py index 8a9feda65..eb4c83311 100644 --- a/pilot/model/llm/llm_utils.py +++ b/pilot/model/llm/llm_utils.py @@ -7,7 +7,6 @@ import time from typing import Optional from pilot.configs.config import Config -from pilot.conversation import Conversation # TODO Rewrite this @@ -44,36 +43,6 @@ def retry_stream_api( return _wrapper -# Overly simple abstraction util we create something better -# simple retry mechanism when getting a rate error or a bad gateway -def create_chat_competion( - conv: Conversation, - model: Optional[str] = None, - temperature: float = None, - max_new_tokens: Optional[int] = None, -) -> str: - """Create a chat completion using the Vicuna-13b - - Args: - messages(List[Message]): The messages to send to the chat completion - model (str, optional): The model to use. Default to None. - temperature (float, optional): The temperature to use. Defaults to 0.7. - max_tokens (int, optional): The max tokens to use. Defaults to None. - - Returns: - str: The response from the chat completion - """ - cfg = Config() - if temperature is None: - temperature = cfg.temperature - - # TODO request vicuna model get response - # convert vicuna message to chat completion. 
- for plugin in cfg.plugins: - if plugin.can_handle_chat_completion(): - pass - - class ChatIO(abc.ABC): @abc.abstractmethod def prompt_for_input(self, role: str) -> str: diff --git a/pilot/model/llm_out/guanaco_llm.py b/pilot/model/llm_out/guanaco_llm.py index 1a2d1ae8b..dd727d19a 100644 --- a/pilot/model/llm_out/guanaco_llm.py +++ b/pilot/model/llm_out/guanaco_llm.py @@ -1,7 +1,6 @@ import torch from threading import Thread from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria -from pilot.conversation import ROLE_ASSISTANT, ROLE_USER def guanaco_generate_output(model, tokenizer, params, device, context_len=2048): diff --git a/pilot/model/llm_out/vllm_llm.py b/pilot/model/llm_out/vllm_llm.py new file mode 100644 index 000000000..07d43dc74 --- /dev/null +++ b/pilot/model/llm_out/vllm_llm.py @@ -0,0 +1,56 @@ +from typing import Dict +from vllm import AsyncLLMEngine +from vllm.utils import random_uuid +from vllm.sampling_params import SamplingParams + + +async def generate_stream( + model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int +): + """ + Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py + """ + prompt = params["prompt"] + request_id = params.pop("request_id") if "request_id" in params else random_uuid() + temperature = float(params.get("temperature", 1.0)) + top_p = float(params.get("top_p", 1.0)) + max_new_tokens = int(params.get("max_new_tokens", 2048)) + echo = bool(params.get("echo", True)) + stop_str = params.get("stop", None) + + stop_token_ids = params.get("stop_token_ids", None) or [] + if tokenizer.eos_token_id is not None: + stop_token_ids.append(tokenizer.eos_token_id) + + # Handle stop_str + stop = set() + if isinstance(stop_str, str) and stop_str != "": + stop.add(stop_str) + elif isinstance(stop_str, list) and stop_str != []: + stop.update(stop_str) + + for tid in stop_token_ids: + if tid is not None: + stop.add(tokenizer.decode(tid)) + + # make sampling params in vllm + top_p = max(top_p, 1e-5) + if temperature <= 1e-5: + top_p = 1.0 + sampling_params = SamplingParams( + n=1, + temperature=temperature, + top_p=top_p, + use_beam_search=False, + stop=list(stop), + max_tokens=max_new_tokens, + ) + results_generator = model.generate(prompt, sampling_params, request_id) + async for request_output in results_generator: + prompt = request_output.prompt + if echo: + text_outputs = [prompt + output.text for output in request_output.outputs] + else: + text_outputs = [output.text for output in request_output.outputs] + text_outputs = " ".join(text_outputs) + yield {"text": text_outputs, "error_code": 0, "usage": {}} diff --git a/pilot/model/llm_utils.py b/pilot/model/llm_utils.py index 690a6afbf..ae711609c 100644 --- a/pilot/model/llm_utils.py +++ b/pilot/model/llm_utils.py @@ -146,7 +146,7 @@ def list_supported_models(): def _list_supported_models( worker_type: str, model_config: Dict[str, str] ) -> List[SupportedModel]: - from pilot.model.adapter import get_llm_model_adapter + from pilot.model.model_adapter import get_llm_model_adapter from pilot.model.parameter import ModelParameters from pilot.model.loader import _get_model_real_path diff --git a/pilot/model/loader.py b/pilot/model/loader.py index 39adf24ad..2f5f10c2d 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -3,9 +3,11 @@ from typing import Optional, Dict +from dataclasses import asdict import logging from pilot.configs.model_config import get_device -from pilot.model.adapter import get_llm_model_adapter, 
BaseLLMAdaper, ModelType +from pilot.model.base import ModelType +from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper from pilot.model.parameter import ( ModelParameters, LlamaCppModelParameters, @@ -114,8 +116,9 @@ class ModelLoader: else: raise Exception(f"Unkown model type {model_type}") - def loader_with_params(self, model_params: ModelParameters): - llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) + def loader_with_params( + self, model_params: ModelParameters, llm_adapter: LLMModelAdaper + ): model_type = llm_adapter.model_type() self.prompt_template = model_params.prompt_template if model_type == ModelType.HF: @@ -124,11 +127,13 @@ class ModelLoader: return llamacpp_loader(llm_adapter, model_params) elif model_type == ModelType.PROXY: return proxyllm_loader(llm_adapter, model_params) + elif model_type == ModelType.VLLM: + return llm_adapter.load_from_params(model_params) else: raise Exception(f"Unkown model type {model_type}") -def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters): +def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters): import torch from pilot.model.compression import compress_module @@ -181,7 +186,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters f"Current model {model_params.model_name} not supported quantization" ) # default loader - model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs) + model, tokenizer = llm_adapter.load(model_params.model_path, kwargs) if model_params.load_8bit and num_gpus == 1 and tokenizer: # TODO merge current code into `load_huggingface_quantization_model` @@ -204,7 +209,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters def load_huggingface_quantization_model( - llm_adapter: BaseLLMAdaper, + llm_adapter: LLMModelAdaper, model_params: ModelParameters, kwargs: Dict, max_memory: Dict[int, str], @@ -339,7 +344,7 @@ def load_huggingface_quantization_model( return model, tokenizer -def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters): +def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters): try: from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel except ImportError as exc: @@ -353,7 +358,7 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam return model, tokenizer -def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters): +def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters): from pilot.model.proxy.llms.proxy_model import ProxyModel logger.info("Load proxyllm") diff --git a/pilot/model/model_adapter.py b/pilot/model/model_adapter.py new file mode 100644 index 000000000..6b354ddfb --- /dev/null +++ b/pilot/model/model_adapter.py @@ -0,0 +1,431 @@ +from __future__ import annotations + +from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING +import dataclasses +import logging +import threading +import os +from functools import cache +from pilot.model.base import ModelType +from pilot.model.parameter import ( + ModelParameters, + LlamaCppModelParameters, + ProxyModelParameters, +) +from pilot.scene.base_message import ModelMessage, ModelMessageRoleType +from pilot.utils.parameter_utils import ( + _extract_parameter_details, + _build_parameter_class, + _get_dataclass_print_str, +) + +try: + from fastchat.conversation import ( + Conversation, + register_conv_template, + 
SeparatorStyle, + ) +except ImportError as exc: + raise ValueError( + "Could not import python package: fschat " + "Please install fastchat by command `pip install fschat` " + ) from exc + +if TYPE_CHECKING: + from fastchat.model.model_adapter import BaseModelAdapter + from pilot.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper + from torch.nn import Module as TorchNNModule + +logger = logging.getLogger(__name__) + +thread_local = threading.local() + + +_OLD_MODELS = [ + "llama-cpp", + "proxyllm", + "gptj-6b", +] + + +class LLMModelAdaper: + """New Adapter for DB-GPT LLM models""" + + def model_type(self) -> str: + return ModelType.HF + + def model_param_class(self, model_type: str = None) -> ModelParameters: + """Get the startup parameters instance of the model""" + model_type = model_type if model_type else self.model_type() + if model_type == ModelType.LLAMA_CPP: + return LlamaCppModelParameters + elif model_type == ModelType.PROXY: + return ProxyModelParameters + return ModelParameters + + def load(self, model_path: str, from_pretrained_kwargs: dict): + """Load model and tokenizer""" + raise NotImplementedError + + def load_from_params(self, params): + """Load the model and tokenizer according to the given parameters""" + raise NotImplementedError + + def support_async(self) -> bool: + """Whether the loaded model supports asynchronous calls""" + return False + + def get_generate_stream_function(self, model, model_path: str): + """Get the generate stream function of the model""" + raise NotImplementedError + + def get_async_generate_stream_function(self, model, model_path: str): + """Get the asynchronous generate stream function of the model""" + raise NotImplementedError + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> "Conversation": + """Get the default conv template""" + raise NotImplementedError + + def model_adaptation( + self, + params: Dict, + model_name: str, + model_path: str, + prompt_template: str = None, + ) -> Tuple[Dict, Dict]: + """Params adaptation""" + conv = self.get_default_conv_template(model_name, model_path) + messages = params.get("messages") + # Some model scontext to dbgpt server + model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False} + + if messages: + # Dict message to ModelMessage + messages = [ + m if isinstance(m, ModelMessage) else ModelMessage(**m) + for m in messages + ] + params["messages"] = messages + + if prompt_template: + logger.info(f"Use prompt template {prompt_template} from config") + conv = get_conv_template(prompt_template) + if not conv or not messages: + # Nothing to do + logger.info( + f"No conv from model_path {model_path} or no messages in params, {self}" + ) + return params, model_context + + conv = conv.copy() + system_messages = [] + for message in messages: + role, content = None, None + if isinstance(message, ModelMessage): + role = message.role + content = message.content + elif isinstance(message, dict): + role = message["role"] + content = message["content"] + else: + raise ValueError(f"Invalid message type: {message}") + + if role == ModelMessageRoleType.SYSTEM: + # Support for multiple system messages + system_messages.append(content) + elif role == ModelMessageRoleType.HUMAN: + conv.append_message(conv.roles[0], content) + elif role == ModelMessageRoleType.AI: + conv.append_message(conv.roles[1], content) + else: + raise ValueError(f"Unknown role: {role}") + if system_messages: + conv.set_system_message("".join(system_messages)) + + # Add a blank message for the 
assistant. + conv.append_message(conv.roles[1], None) + new_prompt = conv.get_prompt() + # Overwrite the original prompt + # TODO remove bos token and eos token from tokenizer_config.json of model + prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", "")) + model_context["prompt_echo_len_char"] = prompt_echo_len_char + model_context["echo"] = params.get("echo", True) + model_context["has_format_prompt"] = True + params["prompt"] = new_prompt + + # Overwrite model params: + params["stop"] = conv.stop_str + + return params, model_context + + +class OldLLMModelAdaperWrapper(LLMModelAdaper): + """Wrapping old adapter, which may be removed later""" + + def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None: + self._adapter = adapter + self._chat_adapter = chat_adapter + + def model_type(self) -> str: + return self._adapter.model_type() + + def model_param_class(self, model_type: str = None) -> ModelParameters: + return self._adapter.model_param_class(model_type) + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> "Conversation": + return self._chat_adapter.get_conv_template(model_path) + + def load(self, model_path: str, from_pretrained_kwargs: dict): + return self._adapter.loader(model_path, from_pretrained_kwargs) + + def get_generate_stream_function(self, model, model_path: str): + return self._chat_adapter.get_generate_stream_func(model_path) + + +class FastChatLLMModelAdaperWrapper(LLMModelAdaper): + """Wrapping fastchat adapter""" + + def __init__(self, adapter: "BaseModelAdapter") -> None: + self._adapter = adapter + + def load(self, model_path: str, from_pretrained_kwargs: dict): + return self._adapter.load_model(model_path, from_pretrained_kwargs) + + def get_generate_stream_function(self, model: "TorchNNModule", model_path: str): + from fastchat.model.model_adapter import get_generate_stream_function + + return get_generate_stream_function(model, model_path) + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> "Conversation": + return self._adapter.get_default_conv_template(model_path) + + +def get_conv_template(name: str) -> "Conversation": + """Get a conversation template.""" + from fastchat.conversation import get_conv_template + + return get_conv_template(name) + + +@cache +def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation": + try: + adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True) + return adapter.get_default_conv_template(model_name, model_path) + except Exception: + return None + + +@cache +def get_llm_model_adapter( + model_name: str, + model_path: str, + use_fastchat: bool = True, + use_fastchat_monkey_patch: bool = False, + model_type: str = None, +) -> LLMModelAdaper: + if model_type == ModelType.VLLM: + logger.info("Current model type is vllm, return VLLMModelAdaperWrapper") + return VLLMModelAdaperWrapper() + + must_use_old = any(m in model_name for m in _OLD_MODELS) + if use_fastchat and not must_use_old: + logger.info("Use fastchat adapter") + adapter = _get_fastchat_model_adapter( + model_name, + model_path, + _fastchat_get_adapter_monkey_patch, + use_fastchat_monkey_patch=use_fastchat_monkey_patch, + ) + return FastChatLLMModelAdaperWrapper(adapter) + else: + from pilot.model.adapter import ( + get_llm_model_adapter as _old_get_llm_model_adapter, + ) + from pilot.server.chat_adapter import get_llm_chat_adapter + + logger.info("Use DB-GPT old adapter") + return OldLLMModelAdaperWrapper( + _old_get_llm_model_adapter(model_name,
model_path), + get_llm_chat_adapter(model_name, model_path), + ) + + +def _get_fastchat_model_adapter( + model_name: str, + model_path: str, + caller: Callable[[str], None] = None, + use_fastchat_monkey_patch: bool = False, +): + from fastchat.model import model_adapter + + _bak_get_model_adapter = model_adapter.get_model_adapter + try: + if use_fastchat_monkey_patch: + model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch + thread_local.model_name = model_name + if caller: + return caller(model_path) + finally: + del thread_local.model_name + model_adapter.get_model_adapter = _bak_get_model_adapter + + +def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None): + if not model_name: + if not hasattr(thread_local, "model_name"): + raise RuntimeError("fastchat get adapter monkey path need model_name") + model_name = thread_local.model_name + from fastchat.model.model_adapter import model_adapters + + for adapter in model_adapters: + if adapter.match(model_name): + logger.info( + f"Found llm model adapter with model name: {model_name}, {adapter}" + ) + return adapter + + model_path_basename = ( + None if not model_path else os.path.basename(os.path.normpath(model_path)) + ) + for adapter in model_adapters: + if model_path_basename and adapter.match(model_path_basename): + logger.info( + f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}" + ) + return adapter + + for adapter in model_adapters: + if model_path and adapter.match(model_path): + logger.info( + f"Found llm model adapter with model path: {model_path}, {adapter}" + ) + return adapter + + raise ValueError( + f"Invalid model adapter for model name {model_name} and model path {model_path}" + ) + + +def _dynamic_model_parser() -> Callable[[None], List[Type]]: + from pilot.utils.parameter_utils import _SimpleArgParser + from pilot.model.parameter import ( + EmbeddingModelParameters, + WorkerType, + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG, + ) + + pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type") + pre_args.parse() + model_name = pre_args.get("model_name") + model_path = pre_args.get("model_path") + worker_type = pre_args.get("worker_type") + model_type = pre_args.get("model_type") + if model_name is None and model_type != ModelType.VLLM: + return None + if worker_type == WorkerType.TEXT2VEC: + return [ + EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get( + model_name, EmbeddingModelParameters + ) + ] + + llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type) + param_class = llm_adapter.model_param_class() + return [param_class] + + +class VLLMModelAdaperWrapper(LLMModelAdaper): + """Wrapping vllm engine""" + + def model_type(self) -> str: + return ModelType.VLLM + + def model_param_class(self, model_type: str = None) -> ModelParameters: + import argparse + from vllm.engine.arg_utils import AsyncEngineArgs + + parser = argparse.ArgumentParser() + parser = AsyncEngineArgs.add_cli_args(parser) + parser.add_argument("--model_name", type=str, help="model name") + parser.add_argument( + "--model_path", + type=str, + help="local model path of the huggingface model to use", + ) + parser.add_argument("--model_type", type=str, help="model type") + parser.add_argument("--device", type=str, default=None, help="device") + # TODO parse prompt templete from `model_name` and `model_path` + parser.add_argument( + "--prompt_template", + type=str, + default=None, + help="Prompt template. 
If None, the prompt template is automatically determined from model path", + ) + + descs = _extract_parameter_details( + parser, + "pilot.model.parameter.VLLMModelParameters", + skip_names=["model"], + overwrite_default_values={"trust_remote_code": True}, + ) + return _build_parameter_class(descs) + + def load_from_params(self, params): + from vllm import AsyncLLMEngine + from vllm.engine.arg_utils import AsyncEngineArgs + import torch + + num_gpus = torch.cuda.device_count() + if num_gpus > 1 and hasattr(params, "tensor_parallel_size"): + setattr(params, "tensor_parallel_size", num_gpus) + logger.info( + f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}" + ) + + params = dataclasses.asdict(params) + params["model"] = params["model_path"] + attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)] + vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs} + # Set the attributes from the parsed arguments. + engine_args = AsyncEngineArgs(**vllm_engine_args_dict) + engine = AsyncLLMEngine.from_engine_args(engine_args) + return engine, engine.engine.tokenizer + + def support_async(self) -> bool: + return True + + def get_async_generate_stream_function(self, model, model_path: str): + from pilot.model.llm_out.vllm_llm import generate_stream + + return generate_stream + + def get_default_conv_template( + self, model_name: str, model_path: str + ) -> "Conversation": + return _auto_get_conv_template(model_name, model_path) + + +# Overriding the configuration of fastchat; we will regularly feed the changes here back to fastchat. +# We also recommend that you modify it directly in the fastchat repository. +register_conv_template( + Conversation( + name="internlm-chat", + system_message="A chat between a curious <|User|> and an <|Bot|>.
The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n", + roles=("<|User|>", "<|Bot|>"), + sep_style=SeparatorStyle.CHATINTERN, + sep="", + sep2="", + stop_token_ids=[1, 103028], + # TODO feedback stop_str to fastchat + stop_str="", + ), + override=True, +) diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index 6a92870b5..2a3df1835 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -64,6 +64,13 @@ class ModelWorkerParameters(BaseModelParameters): default=None, metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"}, ) + model_type: Optional[str] = field( + default="huggingface", + metadata={ + "help": "Model type: huggingface, llama.cpp, proxy and vllm", + "tags": "fixed", + }, + ) host: Optional[str] = field( default="0.0.0.0", metadata={"help": "Model worker deploy host"} ) @@ -163,7 +170,7 @@ class ModelParameters(BaseModelParameters): model_type: Optional[str] = field( default="huggingface", metadata={ - "help": "Model type, huggingface, llama.cpp and proxy", + "help": "Model type: huggingface, llama.cpp, proxy and vllm", "tags": "fixed", }, ) @@ -292,7 +299,7 @@ class ProxyModelParameters(BaseModelParameters): model_type: Optional[str] = field( default="proxy", metadata={ - "help": "Model type, huggingface, llama.cpp and proxy", + "help": "Model type: huggingface, llama.cpp, proxy and vllm", "tags": "fixed", }, ) diff --git a/pilot/server/base.py b/pilot/server/base.py index 8113b6fee..3b2d7010b 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -19,7 +19,7 @@ def signal_handler(sig, frame): os._exit(0) -def async_db_summery(system_app: SystemApp): +def async_db_summary(system_app: SystemApp): from pilot.summary.db_summary_client import DBSummaryClient client = DBSummaryClient(system_app=system_app) @@ -79,7 +79,7 @@ def _create_model_start_listener(system_app: SystemApp): print("begin run _add_app_startup_event") conn_manage = ConnectManager(system_app) cfg.LOCAL_DB_MANAGE = conn_manage - async_db_summery(system_app) + async_db_summary(system_app) return startup_event diff --git a/pilot/server/chat_adapter.py b/pilot/server/chat_adapter.py index 80f22effe..cb486021b 100644 --- a/pilot/server/chat_adapter.py +++ b/pilot/server/chat_adapter.py @@ -1,3 +1,7 @@ +""" +This code file will be deprecated in the future. +We have integrated fastchat. 
For details, see: pilot/model/model_adapter.py +""" #!/usr/bin/env python3 # -*- coding: utf-8 -*- diff --git a/pilot/server/llm_manage/__init__.py b/pilot/server/llm_manage/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/server/llm_manage/request/__init__.py b/pilot/server/llm_manage/request/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/summary/db_summary.py b/pilot/summary/db_summary.py index 30f5e2e48..86306a31d 100644 --- a/pilot/summary/db_summary.py +++ b/pilot/summary/db_summary.py @@ -1,18 +1,18 @@ class DBSummary: def __init__(self, name): self.name = name - self.summery = None + self.summary = None self.tables = [] self.metadata = str - def get_summery(self): - return self.summery + def get_summary(self): + return self.summary class TableSummary: def __init__(self, name): self.name = name - self.summery = None + self.summary = None self.fields = [] self.indexes = [] @@ -20,12 +20,12 @@ class TableSummary: class FieldSummary: def __init__(self, name): self.name = name - self.summery = None + self.summary = None self.data_type = None class IndexSummary: def __init__(self, name): self.name = name - self.summery = None + self.summary = None self.bind_fields = [] diff --git a/pilot/summary/db_summary_client.py b/pilot/summary/db_summary_client.py index 23597e0f0..6ba28afe7 100644 --- a/pilot/summary/db_summary_client.py +++ b/pilot/summary/db_summary_client.py @@ -47,13 +47,13 @@ class DBSummaryClient: "embeddings": embeddings, } embedding = StringEmbedding( - file_path=db_summary_client.get_summery(), + file_path=db_summary_client.get_summary(), vector_store_config=vector_store_config, ) self.init_db_profile(db_summary_client, dbname, embeddings) if not embedding.vector_name_exist(): if CFG.SUMMARY_CONFIG == "FAST": - for vector_table_info in db_summary_client.get_summery(): + for vector_table_info in db_summary_client.get_summary(): embedding = StringEmbedding( vector_table_info, vector_store_config, @@ -61,7 +61,7 @@ class DBSummaryClient: embedding.source_embedding() else: embedding = StringEmbedding( - file_path=db_summary_client.get_summery(), + file_path=db_summary_client.get_summary(), vector_store_config=vector_store_config, ) embedding.source_embedding() @@ -144,8 +144,8 @@ class DBSummaryClient: vector_store_config=vector_store_config, embedding_factory=embedding_factory, ) - table_summery = knowledge_embedding_client.similar_search(query, 1) - related_table_summaries.append(table_summery[0].page_content) + table_summary = knowledge_embedding_client.similar_search(query, 1) + related_table_summaries.append(table_summary[0].page_content) return related_table_summaries def init_db_summary(self): @@ -169,7 +169,7 @@ class DBSummaryClient: "embeddings": embeddings, } embedding = StringEmbedding( - file_path=db_summary_client.get_db_summery(), + file_path=db_summary_client.get_db_summary(), vector_store_config=profile_store_config, ) if not embedding.vector_name_exist(): diff --git a/pilot/summary/rdbms_db_summary.py b/pilot/summary/rdbms_db_summary.py index 25a4b23c3..95c603df8 100644 --- a/pilot/summary/rdbms_db_summary.py +++ b/pilot/summary/rdbms_db_summary.py @@ -12,7 +12,7 @@ class RdbmsSummary(DBSummary): def __init__(self, name, type): self.name = name self.type = type - self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}""" + self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": 
{tps}}}""" self.tables = {} self.tables_info = [] self.vector_tables_info = [] @@ -48,7 +48,7 @@ class RdbmsSummary(DBSummary): for table_name in tables: table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map) - # self.tables[table_name] = table_summary.get_summery() + # self.tables[table_name] = table_summary.get_summary() self.tables[table_name] = table_summary.get_columns() self.table_columns_info.append(table_summary.get_columns()) # self.table_columns_json.append(table_summary.get_summary_json()) @@ -59,18 +59,18 @@ class RdbmsSummary(DBSummary): ) ) self.table_columns_json.append(table_profile) - # self.tables_info.append(table_summary.get_summery()) + # self.tables_info.append(table_summary.get_summary()) - def get_summery(self): + def get_summary(self): if CFG.SUMMARY_CONFIG == "FAST": return self.vector_tables_info else: - return self.summery.format( + return self.summary.format( name=self.name, type=self.type, table_info=";".join(self.tables_info) ) - def get_db_summery(self): - return self.summery.format( + def get_db_summary(self): + return self.summary.format( name=self.name, type=self.type, tables=";".join(self.vector_tables_info), @@ -94,8 +94,8 @@ class RdbmsTableSummary(TableSummary): def __init__(self, instance, dbname, name, comment_map): self.name = name self.dbname = dbname - self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" - self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}""" + self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}""" + self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}""" self.fields = [] self.fields_info = [] self.indexes = [] @@ -107,19 +107,19 @@ class RdbmsTableSummary(TableSummary): for field in fields: field_summary = RdbmsFieldsSummary(field) self.fields.append(field_summary) - self.fields_info.append(field_summary.get_summery()) + self.fields_info.append(field_summary.get_summary()) field_names.append(field[0]) - self.column_summery = """{name}({columns_info})""".format( + self.column_summary = """{name}({columns_info})""".format( name=name, columns_info=",".join(field_names) ) for index in indexes: index_summary = RdbmsIndexSummary(index) self.indexes.append(index_summary) - self.indexes_info.append(index_summary.get_summery()) + self.indexes_info.append(index_summary.get_summary()) - self.json_summery = self.json_summery_template.format( + self.json_summary = self.json_summary_template.format( name=name, comment=comment_map[name], fields=self.fields_info, @@ -128,8 +128,8 @@ class RdbmsTableSummary(TableSummary): rows=1000, ) - def get_summery(self): - return self.summery.format( + def get_summary(self): + return self.summary.format( name=self.name, dbname=self.dbname, fields=";".join(self.fields_info), @@ -137,10 +137,10 @@ class RdbmsTableSummary(TableSummary): ) def get_columns(self): - return self.column_summery + return self.column_summary def get_summary_json(self): - return self.json_summery + return self.json_summary class RdbmsFieldsSummary(FieldSummary): @@ -148,14 +148,14 @@ class RdbmsFieldsSummary(FieldSummary): def __init__(self, field): self.name = field[0] - # self.summery = """column name:{name}, column data 
type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ - # self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}""" + # self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """ + # self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}""" self.data_type = field[1] self.default_value = field[2] self.is_nullable = field[3] self.comment = field[4] - def get_summery(self): + def get_summary(self): return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format( name=self.name, data_type=self.data_type, @@ -170,11 +170,11 @@ class RdbmsIndexSummary(IndexSummary): def __init__(self, index): self.name = index[0] - # self.summery = """index name:{name}, index bind columns:{bind_fields}""" - self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}' + # self.summary = """index name:{name}, index bind columns:{bind_fields}""" + self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}' self.bind_fields = index[1] - def get_summery(self): - return self.summery_template.format( + def get_summary(self): + return self.summary_template.format( name=self.name, bind_fields=self.bind_fields ) diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py index 5be747e23..8acba8881 100644 --- a/pilot/utils/parameter_utils.py +++ b/pilot/utils/parameter_utils.py @@ -86,17 +86,7 @@ class BaseParameters: return updated def __str__(self) -> str: - class_name = self.__class__.__name__ - parameters = [ - f"\n\n=========================== {class_name} ===========================\n" - ] - for field_info in fields(self): - value = _get_simple_privacy_field_value(self, field_info) - parameters.append(f"{field_info.name}: {value}") - parameters.append( - "\n======================================================================\n\n" - ) - return "\n".join(parameters) + return _get_dataclass_print_str(self) def to_command_args(self, args_prefix: str = "--") -> List[str]: """Convert the fields of the dataclass to a list of command line arguments. 
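Note on the `summery` -> `summary` renames in pilot/summary/rdbms_db_summary.py above: the renamed template strings define the JSON-ish payloads that get embedded for schema retrieval. The snippet below is illustrative only; the template string is copied from `RdbmsSummary`, while every value passed to `format()` is made-up example data.

```python
# Illustrative only: the template string comes from RdbmsSummary above,
# the values passed to format() are example data.
summary_template = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""

print(summary_template.format(name="dbgpt_test", type="mysql", tables="users;orders", qps="10", tps=5))
# -> {"database_name": "dbgpt_test", "type": "mysql", "tables": "users;orders", "qps": "10", "tps": 5}
```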
@@ -110,6 +100,20 @@ class BaseParameters: return _dict_to_command_args(asdict(self), args_prefix=args_prefix) +def _get_dataclass_print_str(obj): + class_name = obj.__class__.__name__ + parameters = [ + f"\n\n=========================== {class_name} ===========================\n" + ] + for field_info in fields(obj): + value = _get_simple_privacy_field_value(obj, field_info) + parameters.append(f"{field_info.name}: {value}") + parameters.append( + "\n======================================================================\n\n" + ) + return "\n".join(parameters) + + def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]: """Convert dict to a list of command line arguments @@ -493,9 +497,10 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type: if not desc: raise ValueError("Parameter descriptions cant be empty") param_class_str = desc[0].param_class - param_class = import_from_string(param_class_str, ignore_import_error=True) - if param_class: - return param_class + if param_class_str: + param_class = import_from_string(param_class_str, ignore_import_error=True) + if param_class: + return param_class module_name, _, class_name = param_class_str.rpartition(".") fields_dict = {} # This will store field names and their default values or field() @@ -520,6 +525,71 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type: return result_class +def _extract_parameter_details( + parser: argparse.ArgumentParser, + param_class: str = None, + skip_names: List[str] = None, + overwrite_default_values: Dict = {}, +) -> List[ParameterDescription]: + descriptions = [] + + for action in parser._actions: + if ( + action.default == argparse.SUPPRESS + ): # typically this means the argument was not provided + continue + + # determine parameter class (store_true/store_false are flags) + flag_or_option = ( + "flag" if isinstance(action, argparse._StoreConstAction) else "option" + ) + + # extract parameter name (use the first option string, typically the long form) + param_name = action.option_strings[0] if action.option_strings else action.dest + if param_name.startswith("--"): + param_name = param_name[2:] + if param_name.startswith("-"): + param_name = param_name[1:] + + param_name = param_name.replace("-", "_") + + if skip_names and param_name in skip_names: + continue + + # gather other details + default_value = action.default + if param_name in overwrite_default_values: + default_value = overwrite_default_values[param_name] + arg_type = ( + action.type if not callable(action.type) else str(action.type.__name__) + ) + description = action.help + + # determine if the argument is required + required = action.required + + # extract valid values for choices, if provided + valid_values = action.choices if action.choices is not None else None + + # set ext_metadata as an empty dict for now, can be updated later if needed + ext_metadata = {} + + descriptions.append( + ParameterDescription( + param_class=param_class, + param_name=param_name, + param_type=arg_type, + default_value=default_value, + description=description, + required=required, + valid_values=valid_values, + ext_metadata=ext_metadata, + ) + ) + + return descriptions + + class _SimpleArgParser: def __init__(self, *args): self.params = {arg.replace("_", "-"): None for arg in args} diff --git a/pilot/utils/tests/__init__.py b/pilot/utils/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pilot/utils/tests/test_parameter_utils.py b/pilot/utils/tests/test_parameter_utils.py new 
file mode 100644 index 000000000..cf95b4ae7 --- /dev/null +++ b/pilot/utils/tests/test_parameter_utils.py @@ -0,0 +1,81 @@ +import argparse +import pytest +from pilot.utils.parameter_utils import _extract_parameter_details + + +def create_parser(): + parser = argparse.ArgumentParser() + return parser + + +@pytest.mark.parametrize( + "argument, expected_param_name, default_value, param_type, expected_param_type, description", + [ + ("--option", "option", "value", str, "str", "An option argument"), + ("-option", "option", "value", str, "str", "An option argument"), + ("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"), + ("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"), + ], +) +def test_extract_parameter_details_option_argument( + argument, + expected_param_name, + default_value, + param_type, + expected_param_type, + description, +): + parser = create_parser() + parser.add_argument( + argument, default=default_value, type=param_type, help=description + ) + descriptions = _extract_parameter_details(parser) + + assert len(descriptions) == 1 + desc = descriptions[0] + + assert desc.param_name == expected_param_name + assert desc.param_type == expected_param_type + assert desc.default_value == default_value + assert desc.description == description + assert desc.required == False + assert desc.valid_values is None + + +def test_extract_parameter_details_flag_argument(): + parser = create_parser() + parser.add_argument("--flag", action="store_true", help="A flag argument") + descriptions = _extract_parameter_details(parser) + + assert len(descriptions) == 1 + desc = descriptions[0] + + assert desc.param_name == "flag" + assert desc.description == "A flag argument" + assert desc.required == False + + +def test_extract_parameter_details_choice_argument(): + parser = create_parser() + parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument") + descriptions = _extract_parameter_details(parser) + + assert len(descriptions) == 1 + desc = descriptions[0] + + assert desc.param_name == "choice" + assert desc.valid_values == ["A", "B", "C"] + + +def test_extract_parameter_details_required_argument(): + parser = create_parser() + parser.add_argument( + "--required", required=True, type=int, help="A required argument" + ) + descriptions = _extract_parameter_details(parser) + + assert len(descriptions) == 1 + desc = descriptions[0] + + assert desc.param_name == "required" + assert desc.required == True diff --git a/requirements/dev-requirements.txt b/requirements/dev-requirements.txt index 00865cb52..eb54df339 100644 --- a/requirements/dev-requirements.txt +++ b/requirements/dev-requirements.txt @@ -9,4 +9,6 @@ pytest-mock pytest-recording pytesseract==0.3.10 # python code format -black \ No newline at end of file +black +# for git hooks +pre-commit \ No newline at end of file diff --git a/setup.py b/setup.py index ee95f1338..28b71fca9 100644 --- a/setup.py +++ b/setup.py @@ -203,9 +203,9 @@ def get_cuda_version() -> str: def torch_requires( - torch_version: str = "2.0.0", - torchvision_version: str = "0.15.1", - torchaudio_version: str = "2.0.1", + torch_version: str = "2.0.1", + torchvision_version: str = "0.15.2", + torchaudio_version: str = "2.0.2", ): torch_pkgs = [ f"torch=={torch_version}", @@ -298,6 +298,7 @@ def core_requires(): ] setup_spec.extras["framework"] = [ + "fschat", "coloredlogs", "httpx", "sqlparse==0.4.4", @@ -396,12 +397,19 @@ def gpt4all_requires(): setup_spec.extras["gpt4all"] = ["gpt4all"] +def vllm_requires(): + """ + pip install
"db-gpt[vllm]" + """ + setup_spec.extras["vllm"] = ["vllm"] + + def default_requires(): """ pip install "db-gpt[default]" """ setup_spec.extras["default"] = [ - "tokenizers==0.13.2", + "tokenizers==0.13.3", "accelerate>=0.20.3", "sentence-transformers", "protobuf==3.20.3", @@ -435,6 +443,7 @@ all_vector_store_requires() all_datasource_requires() openai_requires() gpt4all_requires() +vllm_requires() # must be last default_requires()