Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-01 16:18:27 +00:00)

Merge remote-tracking branch 'origin/main' into feat_rag_graph

This commit is contained in:
commit e2a1990696
.gitignore (vendored): 4 changed lines
@@ -151,4 +151,6 @@ pilot/mock_datas/db-gpt-test.db.wal
logswebserver.log.*
.history/*
.plugin_env
.plugin_env
# Ignore for now
thirdparty
@@ -10,11 +10,11 @@ git clone https://github.com/<YOUR-GITHUB-USERNAME>/DB-GPT
```
3. Install the project requirements
```
pip install -r requirements.txt
pip install -r requirements/dev-requirements.txt
```
4. Install pre-commit hooks
```
pre-commit install
pre-commit install --allow-missing-config
```
5. Create a new branch for your changes using the following command:
@@ -41,6 +41,7 @@ services:
    restart: unless-stopped
    networks:
      - dbgptnet
    ipc: host
    deploy:
      resources:
        reservations:
@@ -1,6 +1,6 @@
#!/bin/bash

docker run --gpus all -d \
docker run --ipc host --gpus all -d \
    -p 5000:5000 \
    -e LOCAL_DB_TYPE=sqlite \
    -e LOCAL_DB_PATH=data/default_sqlite.db \
@@ -4,7 +4,7 @@
PROXY_API_KEY="$PROXY_API_KEY"
PROXY_SERVER_URL="${PROXY_SERVER_URL-'https://api.openai.com/v1/chat/completions'}"

docker run --gpus all -d \
docker run --ipc host --gpus all -d \
    -p 5000:5000 \
    -e LOCAL_DB_TYPE=sqlite \
    -e LOCAL_DB_PATH=data/default_sqlite.db \
@@ -21,6 +21,7 @@ services:
    restart: unless-stopped
    networks:
      - dbgptnet
    ipc: host
    deploy:
      resources:
        reservations:
@@ -47,7 +47,7 @@ You can execute the command `bash docker/build_all_images.sh --help` to see more
**Run with local model and SQLite database**

```bash
docker run --gpus all -d \
docker run --ipc host --gpus all -d \
    -p 5000:5000 \
    -e LOCAL_DB_TYPE=sqlite \
    -e LOCAL_DB_PATH=data/default_sqlite.db \
@@ -73,7 +73,7 @@ docker logs dbgpt -f
**Run with local model and MySQL database**

```bash
docker run --gpus all -d -p 3306:3306 \
docker run --ipc host --gpus all -d -p 3306:3306 \
    -p 5000:5000 \
    -e LOCAL_DB_HOST=127.0.0.1 \
    -e LOCAL_DB_PASSWORD=aa123456 \
@@ -30,3 +30,4 @@ Multi LLMs Support, Supports multiple large language models, currently supportin

./llama/llama_cpp.md
./quantization/quantization.md
./vllm/vllm.md
docs/getting_started/install/llm/vllm/vllm.md (new file, 26 lines)
@@ -0,0 +1,26 @@
vLLM
==================================

[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use library for LLM inference and serving.

## Running vLLM

### Installing Dependencies

vLLM is an optional dependency in DB-GPT, and you can manually install it using the following command:

```bash
pip install -e ".[vllm]"
```

### Modifying the Configuration File

Next, you can directly modify your `.env` file to enable vllm.

```env
LLM_MODEL=vicuna-13b-v1.5
MODEL_TYPE=vllm
```
You can view the models supported by vLLM [here](https://vllm.readthedocs.io/en/latest/models/supported_models.html#supported-models)

Then you can run it according to [Run](https://db-gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run).
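For reference only, a minimal Python sketch (not part of this commit or the docs above) of what `MODEL_TYPE=vllm` ultimately drives: the `VLLMModelAdaperWrapper` added in `pilot/model/model_adapter.py` builds a vLLM `AsyncLLMEngine` from the parsed parameters. The model path below is an illustrative assumption, and vLLM must already be installed via `pip install -e ".[vllm]"`.

```python
# Illustrative sketch only; mirrors what the vLLM adapter wrapper does internally.
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs

# Placeholder path; DB-GPT resolves the real path from LLM_MODEL and its model directory.
engine_args = AsyncEngineArgs(model="/data/models/vicuna-13b-v1.5", trust_remote_code=True)
engine = AsyncLLMEngine.from_engine_args(engine_args)
tokenizer = engine.engine.tokenizer  # same (engine, tokenizer) pair the wrapper returns
```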
@@ -0,0 +1,79 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) 2023, csunny
# This file is distributed under the same license as the DB-GPT package.
# FIRST AUTHOR <EMAIL@ADDRESS>, 2023.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: DB-GPT 👏👏 0.3.9\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2023-10-09 19:46+0800\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language: zh_CN\n"
"Language-Team: zh_CN <LL@li.org>\n"
"Plural-Forms: nplurals=1; plural=0;\n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Generated-By: Babel 2.12.1\n"

#: ../../getting_started/install/llm/vllm/vllm.md:1
#: 9193438ba52148a3b71f190beeb4ef42
msgid "vLLM"
msgstr ""

#: ../../getting_started/install/llm/vllm/vllm.md:4
#: c30d032965794d7e81636581324be45d
msgid ""
"[vLLM](https://github.com/vllm-project/vllm) is a fast and easy-to-use "
"library for LLM inference and serving."
msgstr "[vLLM](https://github.com/vllm-project/vllm) 是一个快速且易于使用的 LLM 推理和服务的库。"

#: ../../getting_started/install/llm/vllm/vllm.md:6
#: b399c7268e0448cb893fbcb11a480849
msgid "Running vLLM"
msgstr "运行 vLLM"

#: ../../getting_started/install/llm/vllm/vllm.md:8
#: 7bed52b8bac946069a24df9e94098df5
msgid "Installing Dependencies"
msgstr "安装依赖"

#: ../../getting_started/install/llm/vllm/vllm.md:10
#: fd50a9f3e1b1459daa3b1a0cd610d1a3
msgid ""
"vLLM is an optional dependency in DB-GPT, and you can manually install it"
" using the following command:"
msgstr "vLLM 在 DB-GPT 是一个可选依赖, 你可以使用下面的命令手动安装它:"

#: ../../getting_started/install/llm/vllm/vllm.md:16
#: 44b251bc6b2c41ebaad9fd5a6a204c7c
msgid "Modifying the Configuration File"
msgstr "修改配置文件"

#: ../../getting_started/install/llm/vllm/vllm.md:18
#: 37f4e65148fa4339969265107b70b8fe
msgid "Next, you can directly modify your `.env` file to enable vllm."
msgstr "你可以直接修改你的 `.env` 文件。"

#: ../../getting_started/install/llm/vllm/vllm.md:24
#: 15d79c9417d04e779fa00a08a05e30d7
msgid ""
"You can view the models supported by vLLM "
"[here](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models)"
msgstr ""
"你可以在 "
"[这里](https://vllm.readthedocs.io/en/latest/models/supported_models.html"
"#supported-models) 查看 vLLM 支持的模型。"

#: ../../getting_started/install/llm/vllm/vllm.md:26
#: 28d90b1fdf6943d9969c8668a7c1094b
msgid ""
"Then you can run it according to [Run](https://db-"
"gpt.readthedocs.io/en/latest/getting_started/install/deploy/deploy.html#run)."
msgstr ""
"然后你可以根据[运行]"
"(https://db-gpt.readthedocs.io/projects/db-gpt-docs-zh-cn/zh_CN/latest/getting_started/install/deploy/deploy.html#run)来启动项目。"
@@ -201,6 +201,9 @@ class Config(metaclass=Singleton):

        self.SYSTEM_APP: Optional["SystemApp"] = None

        ### Temporary configuration
        self.USE_FASTCHAT: bool = os.getenv("USE_FASTCHAT", "True").lower() == "true"

    def set_debug_mode(self, value: bool) -> None:
        """Set the debug mode value"""
        self.debug_mode = value
@@ -76,6 +76,8 @@ LLM_MODEL_CONFIG = {
    "internlm-7b": os.path.join(MODEL_PATH, "internlm-chat-7b"),
    "internlm-7b-8k": os.path.join(MODEL_PATH, "internlm-chat-7b-8k"),
    "internlm-20b": os.path.join(MODEL_PATH, "internlm-20b-chat"),
    # For test now
    "opt-125m": os.path.join(MODEL_PATH, "opt-125m"),
}

EMBEDDING_MODEL_CONFIG = {
@@ -29,7 +29,7 @@ class SeparatorStyle(Enum):


@dataclasses.dataclass
class Conversation:
class OldConversation:
    """This class keeps all conversation history."""

    system: str
@@ -81,7 +81,7 @@ class Conversation:
        return ret

    def copy(self):
        return Conversation(
        return OldConversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
@@ -104,7 +104,7 @@ class Conversation:
    }


conv_default = Conversation(
conv_default = OldConversation(
    system=None,
    roles=("human", "ai"),
    messages=[],
@@ -148,7 +148,7 @@ conv_default = Conversation(
# )


conv_one_shot = Conversation(
conv_one_shot = OldConversation(
    system="You are a DB-GPT. Please provide me with user input and all table information known in the database, so I can accurately query tables are involved in the user input. If there are multiple tables involved, I will separate them by comma. Here is an example:",
    roles=("USER", "ASSISTANT"),
    messages=(
@@ -179,7 +179,7 @@ conv_one_shot = Conversation(
    sep="###",
)

conv_vicuna_v1 = Conversation(
conv_vicuna_v1 = OldConversation(
    system="A chat between a curious user and an artificial intelligence assistant. who very familiar with database related knowledge. "
    "The assistant gives helpful, detailed, professional and polite answers to the user's questions. ",
    roles=("USER", "ASSISTANT"),
@@ -190,7 +190,7 @@ conv_vicuna_v1 = Conversation(
    sep2="</s>",
)

auto_dbgpt_one_shot = Conversation(
auto_dbgpt_one_shot = OldConversation(
    system="You are DB-GPT, an AI designed to answer questions about HackerNews by query `hackerbews` database in MySQL. "
    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
    roles=("USER", "ASSISTANT"),
@@ -263,7 +263,7 @@ auto_dbgpt_one_shot = Conversation(
    sep="###",
)

auto_dbgpt_without_shot = Conversation(
auto_dbgpt_without_shot = OldConversation(
    system="You are DB-GPT, an AI designed to answer questions about users by query `users` database in MySQL. "
    "Your decisions must always be made independently without seeking user assistance. Play to your strengths as an LLM and pursue simple strategies with no legal complications.",
    roles=("USER", "ASSISTANT"),
@@ -1,3 +1,7 @@
"""
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

@@ -13,6 +17,8 @@ from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
)
from pilot.model.base import ModelType

from pilot.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
@@ -26,15 +32,6 @@ logger = logging.getLogger(__name__)
CFG = Config()


class ModelType:
    """ "Type of model"""

    HF = "huggingface"
    LLAMA_CPP = "llama.cpp"
    PROXY = "proxy"
    # TODO, support more model type


class BaseLLMAdaper:
    """The Base class for multi model, in our project.
    We will support those model, which performance resemble ChatGPT"""
@@ -95,33 +92,6 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
    )


def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.parameter import (
        EmbeddingModelParameters,
        WorkerType,
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    )

    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    worker_type = pre_args.get("worker_type")
    if model_name is None:
        return None
    if worker_type == WorkerType.TEXT2VEC:
        return [
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                model_name, EmbeddingModelParameters
            )
        ]

    llm_adapter = get_llm_model_adapter(model_name, model_path)
    param_class = llm_adapter.model_param_class()
    return [param_class]


def _parse_model_param_class(model_name: str, model_path: str) -> ModelParameters:
    try:
        llm_adapter = get_llm_model_adapter(model_name, model_path)
@@ -15,6 +15,16 @@ class Message(TypedDict):
    content: str


class ModelType:
    """ "Type of model"""

    HF = "huggingface"
    LLAMA_CPP = "llama.cpp"
    PROXY = "proxy"
    VLLM = "vllm"
    # TODO, support more model type


@dataclass
class ModelInstance:
    """Model instance info"""
@@ -404,7 +404,7 @@ def stop_model_controller(port: int):


def _model_dynamic_factory() -> Callable[[None], List[Type]]:
    from pilot.model.adapter import _dynamic_model_parser
    from pilot.model.model_adapter import _dynamic_model_parser

    param_class = _dynamic_model_parser()
    fix_class = [ModelWorkerParameters]
@@ -1,26 +1,34 @@
import os
import logging
from typing import Dict, Iterator, List

from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.base import ModelOutput
from pilot.model.loader import ModelLoader, _get_model_real_path
from pilot.model.parameter import ModelParameters
from pilot.model.cluster.worker_base import ModelWorker
from pilot.server.chat_adapter import get_llm_chat_adapter, BaseChatAdpter
from pilot.utils.model_utils import _clear_model_cache
from pilot.utils.parameter_utils import EnvArgumentParser

logger = logging.getLogger(__name__)

_torch_imported = False
try:
    import torch

    _torch_imported = True
except ImportError:
    pass


class DefaultModelWorker(ModelWorker):
    def __init__(self) -> None:
        self.model = None
        self.tokenizer = None
        self._model_params = None
        self.llm_adapter: BaseLLMAdaper = None
        self.llm_chat_adapter: BaseChatAdpter = None
        self.llm_adapter: LLMModelAdaper = None
        self._support_async = False

    def load_worker(self, model_name: str, model_path: str, **kwargs) -> None:
        if model_path.endswith("/"):
@@ -29,18 +37,24 @@ class DefaultModelWorker(ModelWorker):
        self.model_name = model_name
        self.model_path = model_path

        self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
        model_type = kwargs.get("model_type")
        ### Temporary configuration, fastchat will be used by default in the future.
        use_fastchat = os.getenv("USE_FASTCHAT", "True").lower() == "true"

        self.llm_adapter = get_llm_model_adapter(
            self.model_name,
            self.model_path,
            use_fastchat=use_fastchat,
            model_type=model_type,
        )
        model_type = self.llm_adapter.model_type()
        self.param_cls = self.llm_adapter.model_param_class(model_type)
        self._support_async = self.llm_adapter.support_async()

        logger.info(
            f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
        )

        self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
            self.model_path
        )

        self.ml: ModelLoader = ModelLoader(
            model_path=self.model_path, model_name=self.model_name
        )
@@ -50,6 +64,9 @@ class DefaultModelWorker(ModelWorker):
    def model_param_class(self) -> ModelParameters:
        return self.param_cls

    def support_async(self) -> bool:
        return self._support_async

    def parse_parameters(self, command_args: List[str] = None) -> ModelParameters:
        param_cls = self.model_param_class()
        model_args = EnvArgumentParser()
@@ -77,7 +94,9 @@ class DefaultModelWorker(ModelWorker):
        model_params = self.parse_parameters(command_args)
        self._model_params = model_params
        logger.info(f"Begin load model, model params: {model_params}")
        self.model, self.tokenizer = self.ml.loader_with_params(model_params)
        self.model, self.tokenizer = self.ml.loader_with_params(
            model_params, self.llm_adapter
        )

    def stop(self) -> None:
        if not self.model:
@@ -90,51 +109,26 @@ class DefaultModelWorker(ModelWorker):
        _clear_model_cache(self._model_params.device)

    def generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        torch_imported = False
        try:
            import torch

            torch_imported = True
        except ImportError:
            pass
        try:
            # params adaptation
            params, model_context = self.llm_chat_adapter.model_adaptation(
                params, self.ml.model_path, prompt_template=self.ml.prompt_template
            params, model_context, generate_stream_func = self._prepare_generate_stream(
                params
            )

            previous_response = ""
            print("stream output:\n")
            for output in self.generate_stream_func(

            for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), self.context_len
            ):
                # Please do not open the output in production!
                # The gpt4all thread shares stdout with the parent process,
                # and opening it may affect the frontend output.
                incremental_output = output[len(previous_response) :]
                # print("output: ", output)
                print(incremental_output, end="", flush=True)
                previous_response = output
                # return some model context to dgt-server
                model_output = ModelOutput(
                    text=output, error_code=0, model_context=model_context
                model_output, incremental_output, output_str = self._handle_output(
                    output, previous_response, model_context
                )
                previous_response = output_str
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
        except Exception as e:
            # Check if the exception is a torch.cuda.CudaError and if torch was imported.
            if torch_imported and isinstance(e, torch.cuda.CudaError):
                model_output = ModelOutput(
                    text="**GPU OutOfMemory, Please Refresh.**", error_code=0
                )
            else:
                model_output = ModelOutput(
                    text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                    error_code=0,
                )
            yield model_output
            yield self._handle_exception(e)

    def generate(self, params: Dict) -> ModelOutput:
        """Generate non stream result"""
@@ -145,3 +139,81 @@ class DefaultModelWorker(ModelWorker):

    def embeddings(self, params: Dict) -> List[List[float]]:
        raise NotImplementedError

    async def async_generate_stream(self, params: Dict) -> Iterator[ModelOutput]:
        try:
            params, model_context, generate_stream_func = self._prepare_generate_stream(
                params
            )

            previous_response = ""

            async for output in generate_stream_func(
                self.model, self.tokenizer, params, get_device(), self.context_len
            ):
                model_output, incremental_output, output_str = self._handle_output(
                    output, previous_response, model_context
                )
                previous_response = output_str
                yield model_output
            print(
                f"\n\nfull stream output:\n{previous_response}\n\nmodel generate_stream params:\n{params}"
            )
        except Exception as e:
            yield self._handle_exception(e)

    async def async_generate(self, params: Dict) -> ModelOutput:
        output = None
        async for out in self.async_generate_stream(params):
            output = out
        return output

    def _prepare_generate_stream(self, params: Dict):
        params, model_context = self.llm_adapter.model_adaptation(
            params,
            self.model_name,
            self.model_path,
            prompt_template=self.ml.prompt_template,
        )
        stream_type = ""
        if self.support_async():
            generate_stream_func = self.llm_adapter.get_async_generate_stream_function(
                self.model, self.model_path
            )
            stream_type = "async "
            logger.info(
                "current generate stream function is asynchronous stream function"
            )
        else:
            generate_stream_func = self.llm_adapter.get_generate_stream_function(
                self.model, self.model_path
            )
        str_prompt = params.get("prompt")
        print(f"model prompt: \n\n{str_prompt}\n\n{stream_type}stream output:\n")
        return params, model_context, generate_stream_func

    def _handle_output(self, output, previous_response, model_context):
        if isinstance(output, dict):
            finish_reason = output.get("finish_reason")
            output = output["text"]
            if finish_reason is not None:
                logger.info(f"finish_reason: {finish_reason}")
        incremental_output = output[len(previous_response) :]
        print(incremental_output, end="", flush=True)
        model_output = ModelOutput(
            text=output, error_code=0, model_context=model_context
        )
        return model_output, incremental_output, output

    def _handle_exception(self, e):
        # Check if the exception is a torch.cuda.CudaError and if torch was imported.
        if _torch_imported and isinstance(e, torch.cuda.CudaError):
            model_output = ModelOutput(
                text="**GPU OutOfMemory, Please Refresh.**", error_code=0
            )
        else:
            model_output = ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=0,
            )
        return model_output
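A self-contained sketch (illustrative names only, not DB-GPT code) of the dispatch pattern this refactor introduces: one shared output-handling path feeding either a synchronous or an asynchronous cumulative-output stream loop.

```python
# Minimal sketch of the shared sync/async streaming dispatch pattern.
import asyncio
from typing import AsyncIterator, Iterator


def sync_stream(prompt: str) -> Iterator[str]:
    text = ""
    for token in prompt.split():
        text += token + " "
        yield text  # cumulative output, like the worker's generate-stream functions


async def async_stream(prompt: str) -> AsyncIterator[str]:
    text = ""
    for token in prompt.split():
        await asyncio.sleep(0)  # stand-in for awaiting an inference engine
        text += token + " "
        yield text


def handle_output(output: str, previous: str) -> tuple[str, str]:
    # Shared "handle output" step: compute the incremental piece, keep the full text.
    incremental = output[len(previous):]
    return incremental, output


def run_sync(prompt: str) -> str:
    previous = ""
    for output in sync_stream(prompt):
        incremental, previous = handle_output(output, previous)
        print(incremental, end="", flush=True)
    return previous


async def run_async(prompt: str) -> str:
    previous = ""
    async for output in async_stream(prompt):
        incremental, previous = handle_output(output, previous)
        print(incremental, end="", flush=True)
    return previous


if __name__ == "__main__":
    run_sync("hello streaming world")
    print()
    asyncio.run(run_async("hello async streaming world"))
```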
@@ -3,7 +3,9 @@ Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/conver

Conversation prompt templates.

TODO Using fastchat core package

This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""

import dataclasses
@@ -1,6 +1,8 @@
"""
Fork from fastchat: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/inference.py

This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
@@ -7,7 +7,6 @@ import time
from typing import Optional

from pilot.configs.config import Config
from pilot.conversation import Conversation


# TODO Rewrite this
@@ -44,36 +43,6 @@ def retry_stream_api(
    return _wrapper


# Overly simple abstraction util we create something better
# simple retry mechanism when getting a rate error or a bad gateway
def create_chat_competion(
    conv: Conversation,
    model: Optional[str] = None,
    temperature: float = None,
    max_new_tokens: Optional[int] = None,
) -> str:
    """Create a chat completion using the Vicuna-13b

    Args:
        messages(List[Message]): The messages to send to the chat completion
        model (str, optional): The model to use. Default to None.
        temperature (float, optional): The temperature to use. Defaults to 0.7.
        max_tokens (int, optional): The max tokens to use. Defaults to None.

    Returns:
        str: The response from the chat completion
    """
    cfg = Config()
    if temperature is None:
        temperature = cfg.temperature

    # TODO request vicuna model get response
    # convert vicuna message to chat completion.
    for plugin in cfg.plugins:
        if plugin.can_handle_chat_completion():
            pass


class ChatIO(abc.ABC):
    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
@@ -1,7 +1,6 @@
import torch
from threading import Thread
from transformers import TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
from pilot.conversation import ROLE_ASSISTANT, ROLE_USER


def guanaco_generate_output(model, tokenizer, params, device, context_len=2048):
pilot/model/llm_out/vllm_llm.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from typing import Dict
from vllm import AsyncLLMEngine
from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams


async def generate_stream(
    model: AsyncLLMEngine, tokenizer, params: Dict, device: str, context_len: int
):
    """
    Adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/vllm_worker.py
    """
    prompt = params["prompt"]
    request_id = params.pop("request_id") if "request_id" in params else random_uuid()
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    echo = bool(params.get("echo", True))
    stop_str = params.get("stop", None)

    stop_token_ids = params.get("stop_token_ids", None) or []
    if tokenizer.eos_token_id is not None:
        stop_token_ids.append(tokenizer.eos_token_id)

    # Handle stop_str
    stop = set()
    if isinstance(stop_str, str) and stop_str != "":
        stop.add(stop_str)
    elif isinstance(stop_str, list) and stop_str != []:
        stop.update(stop_str)

    for tid in stop_token_ids:
        if tid is not None:
            stop.add(tokenizer.decode(tid))

    # make sampling params in vllm
    top_p = max(top_p, 1e-5)
    if temperature <= 1e-5:
        top_p = 1.0
    sampling_params = SamplingParams(
        n=1,
        temperature=temperature,
        top_p=top_p,
        use_beam_search=False,
        stop=list(stop),
        max_tokens=max_new_tokens,
    )
    results_generator = model.generate(prompt, sampling_params, request_id)
    async for request_output in results_generator:
        prompt = request_output.prompt
        if echo:
            text_outputs = [prompt + output.text for output in request_output.outputs]
        else:
            text_outputs = [output.text for output in request_output.outputs]
        text_outputs = " ".join(text_outputs)
        yield {"text": text_outputs, "error_code": 0, "usage": {}}
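Illustrative only (not part of this file): how the async `generate_stream` above would be consumed, assuming a GPU machine with vLLM installed; the model path is a placeholder.

```python
# Sketch of driving generate_stream directly with an AsyncLLMEngine.
import asyncio

from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs

from pilot.model.llm_out.vllm_llm import generate_stream


async def main() -> None:
    engine_args = AsyncEngineArgs(model="/data/models/vicuna-13b-v1.5")  # placeholder path
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    params = {"prompt": "USER: hello\nASSISTANT:", "temperature": 0.7, "max_new_tokens": 64}
    # device and context_len mirror the worker's call; the function ignores them internally.
    async for out in generate_stream(engine, engine.engine.tokenizer, params, "cuda", 4096):
        print(out["text"])


if __name__ == "__main__":
    asyncio.run(main())
```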
@@ -146,7 +146,7 @@ def list_supported_models():
def _list_supported_models(
    worker_type: str, model_config: Dict[str, str]
) -> List[SupportedModel]:
    from pilot.model.adapter import get_llm_model_adapter
    from pilot.model.model_adapter import get_llm_model_adapter
    from pilot.model.parameter import ModelParameters
    from pilot.model.loader import _get_model_real_path

@@ -3,9 +3,11 @@

from typing import Optional, Dict

from dataclasses import asdict
import logging
from pilot.configs.model_config import get_device
from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.base import ModelType
from pilot.model.model_adapter import get_llm_model_adapter, LLMModelAdaper
from pilot.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
@@ -114,8 +116,9 @@ class ModelLoader:
        else:
            raise Exception(f"Unkown model type {model_type}")

    def loader_with_params(self, model_params: ModelParameters):
        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
    def loader_with_params(
        self, model_params: ModelParameters, llm_adapter: LLMModelAdaper
    ):
        model_type = llm_adapter.model_type()
        self.prompt_template = model_params.prompt_template
        if model_type == ModelType.HF:
@@ -124,11 +127,13 @@ class ModelLoader:
            return llamacpp_loader(llm_adapter, model_params)
        elif model_type == ModelType.PROXY:
            return proxyllm_loader(llm_adapter, model_params)
        elif model_type == ModelType.VLLM:
            return llm_adapter.load_from_params(model_params)
        else:
            raise Exception(f"Unkown model type {model_type}")


def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters):
def huggingface_loader(llm_adapter: LLMModelAdaper, model_params: ModelParameters):
    import torch
    from pilot.model.compression import compress_module

@@ -181,7 +186,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters
            f"Current model {model_params.model_name} not supported quantization"
        )
    # default loader
    model, tokenizer = llm_adapter.loader(model_params.model_path, kwargs)
    model, tokenizer = llm_adapter.load(model_params.model_path, kwargs)

    if model_params.load_8bit and num_gpus == 1 and tokenizer:
        # TODO merge current code into `load_huggingface_quantization_model`
@@ -204,7 +209,7 @@ def huggingface_loader(llm_adapter: BaseLLMAdaper, model_params: ModelParameters


def load_huggingface_quantization_model(
    llm_adapter: BaseLLMAdaper,
    llm_adapter: LLMModelAdaper,
    model_params: ModelParameters,
    kwargs: Dict,
    max_memory: Dict[int, str],
@@ -339,7 +344,7 @@ def load_huggingface_quantization_model(
    return model, tokenizer


def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParameters):
def llamacpp_loader(llm_adapter: LLMModelAdaper, model_params: LlamaCppModelParameters):
    try:
        from pilot.model.llm.llama_cpp.llama_cpp import LlamaCppModel
    except ImportError as exc:
@@ -353,7 +358,7 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
    return model, tokenizer


def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
def proxyllm_loader(llm_adapter: LLMModelAdaper, model_params: ProxyModelParameters):
    from pilot.model.proxy.llms.proxy_model import ProxyModel

    logger.info("Load proxyllm")
pilot/model/model_adapter.py (new file, 431 lines)
@@ -0,0 +1,431 @@
from __future__ import annotations

from typing import Callable, List, Dict, Type, Tuple, TYPE_CHECKING
import dataclasses
import logging
import threading
import os
from functools import cache
from pilot.model.base import ModelType
from pilot.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
    ProxyModelParameters,
)
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
from pilot.utils.parameter_utils import (
    _extract_parameter_details,
    _build_parameter_class,
    _get_dataclass_print_str,
)

try:
    from fastchat.conversation import (
        Conversation,
        register_conv_template,
        SeparatorStyle,
    )
except ImportError as exc:
    raise ValueError(
        "Could not import python package: fschat "
        "Please install fastchat by command `pip install fschat` "
    ) from exc

if TYPE_CHECKING:
    from fastchat.model.model_adapter import BaseModelAdapter
    from pilot.model.adapter import BaseLLMAdaper as OldBaseLLMAdaper
    from torch.nn import Module as TorchNNModule

logger = logging.getLogger(__name__)

thread_local = threading.local()


_OLD_MODELS = [
    "llama-cpp",
    "proxyllm",
    "gptj-6b",
]


class LLMModelAdaper:
    """New Adapter for DB-GPT LLM models"""

    def model_type(self) -> str:
        return ModelType.HF

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        """Get the startup parameters instance of the model"""
        model_type = model_type if model_type else self.model_type()
        if model_type == ModelType.LLAMA_CPP:
            return LlamaCppModelParameters
        elif model_type == ModelType.PROXY:
            return ProxyModelParameters
        return ModelParameters

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        """Load model and tokenizer"""
        raise NotImplementedError

    def load_from_params(self, params):
        """Load the model and tokenizer according to the given parameters"""
        raise NotImplementedError

    def support_async(self) -> bool:
        """Whether the loaded model supports asynchronous calls"""
        return False

    def get_generate_stream_function(self, model, model_path: str):
        """Get the generate stream function of the model"""
        raise NotImplementedError

    def get_async_generate_stream_function(self, model, model_path: str):
        """Get the asynchronous generate stream function of the model"""
        raise NotImplementedError

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        """Get the default conv template"""
        raise NotImplementedError

    def model_adaptation(
        self,
        params: Dict,
        model_name: str,
        model_path: str,
        prompt_template: str = None,
    ) -> Tuple[Dict, Dict]:
        """Params adaptation"""
        conv = self.get_default_conv_template(model_name, model_path)
        messages = params.get("messages")
        # Some model scontext to dbgpt server
        model_context = {"prompt_echo_len_char": -1, "has_format_prompt": False}

        if messages:
            # Dict message to ModelMessage
            messages = [
                m if isinstance(m, ModelMessage) else ModelMessage(**m)
                for m in messages
            ]
            params["messages"] = messages

        if prompt_template:
            logger.info(f"Use prompt template {prompt_template} from config")
            conv = get_conv_template(prompt_template)
        if not conv or not messages:
            # Nothing to do
            logger.info(
                f"No conv from model_path {model_path} or no messages in params, {self}"
            )
            return params, model_context

        conv = conv.copy()
        system_messages = []
        for message in messages:
            role, content = None, None
            if isinstance(message, ModelMessage):
                role = message.role
                content = message.content
            elif isinstance(message, dict):
                role = message["role"]
                content = message["content"]
            else:
                raise ValueError(f"Invalid message type: {message}")

            if role == ModelMessageRoleType.SYSTEM:
                # Support for multiple system messages
                system_messages.append(content)
            elif role == ModelMessageRoleType.HUMAN:
                conv.append_message(conv.roles[0], content)
            elif role == ModelMessageRoleType.AI:
                conv.append_message(conv.roles[1], content)
            else:
                raise ValueError(f"Unknown role: {role}")
        if system_messages:
            conv.set_system_message("".join(system_messages))

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        new_prompt = conv.get_prompt()
        # Overwrite the original prompt
        # TODO remote bos token and eos token from tokenizer_config.json of model
        prompt_echo_len_char = len(new_prompt.replace("</s>", "").replace("<s>", ""))
        model_context["prompt_echo_len_char"] = prompt_echo_len_char
        model_context["echo"] = params.get("echo", True)
        model_context["has_format_prompt"] = True
        params["prompt"] = new_prompt

        # Overwrite model params:
        params["stop"] = conv.stop_str

        return params, model_context


class OldLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping old adapter, which may be removed later"""

    def __init__(self, adapter: "OldBaseLLMAdaper", chat_adapter) -> None:
        self._adapter = adapter
        self._chat_adapter = chat_adapter

    def model_type(self) -> str:
        return self._adapter.model_type()

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        return self._adapter.model_param_class(model_type)

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return self._chat_adapter.get_conv_template(model_path)

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        return self._adapter.loader(model_path, from_pretrained_kwargs)

    def get_generate_stream_function(self, model, model_path: str):
        return self._chat_adapter.get_generate_stream_func(model_path)


class FastChatLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping fastchat adapter"""

    def __init__(self, adapter: "BaseModelAdapter") -> None:
        self._adapter = adapter

    def load(self, model_path: str, from_pretrained_kwargs: dict):
        return self._adapter.load_model(model_path, from_pretrained_kwargs)

    def get_generate_stream_function(self, model: "TorchNNModule", model_path: str):
        from fastchat.model.model_adapter import get_generate_stream_function

        return get_generate_stream_function(model, model_path)

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return self._adapter.get_default_conv_template(model_path)


def get_conv_template(name: str) -> "Conversation":
    """Get a conversation template."""
    from fastchat.conversation import get_conv_template

    return get_conv_template(name)


@cache
def _auto_get_conv_template(model_name: str, model_path: str) -> "Conversation":
    try:
        adapter = get_llm_model_adapter(model_name, model_path, use_fastchat=True)
        return adapter.get_default_conv_template(model_name, model_path)
    except Exception:
        return None


@cache
def get_llm_model_adapter(
    model_name: str,
    model_path: str,
    use_fastchat: bool = True,
    use_fastchat_monkey_patch: bool = False,
    model_type: str = None,
) -> LLMModelAdaper:
    if model_type == ModelType.VLLM:
        logger.info("Current model type is vllm, return VLLMModelAdaperWrapper")
        return VLLMModelAdaperWrapper()

    must_use_old = any(m in model_name for m in _OLD_MODELS)
    if use_fastchat and not must_use_old:
        logger.info("Use fastcat adapter")
        adapter = _get_fastchat_model_adapter(
            model_name,
            model_path,
            _fastchat_get_adapter_monkey_patch,
            use_fastchat_monkey_patch=use_fastchat_monkey_patch,
        )
        return FastChatLLMModelAdaperWrapper(adapter)
    else:
        from pilot.model.adapter import (
            get_llm_model_adapter as _old_get_llm_model_adapter,
        )
        from pilot.server.chat_adapter import get_llm_chat_adapter

        logger.info("Use DB-GPT old adapter")
        return OldLLMModelAdaperWrapper(
            _old_get_llm_model_adapter(model_name, model_path),
            get_llm_chat_adapter(model_name, model_path),
        )


def _get_fastchat_model_adapter(
    model_name: str,
    model_path: str,
    caller: Callable[[str], None] = None,
    use_fastchat_monkey_patch: bool = False,
):
    from fastchat.model import model_adapter

    _bak_get_model_adapter = model_adapter.get_model_adapter
    try:
        if use_fastchat_monkey_patch:
            model_adapter.get_model_adapter = _fastchat_get_adapter_monkey_patch
        thread_local.model_name = model_name
        if caller:
            return caller(model_path)
    finally:
        del thread_local.model_name
        model_adapter.get_model_adapter = _bak_get_model_adapter


def _fastchat_get_adapter_monkey_patch(model_path: str, model_name: str = None):
    if not model_name:
        if not hasattr(thread_local, "model_name"):
            raise RuntimeError("fastchat get adapter monkey path need model_name")
        model_name = thread_local.model_name
    from fastchat.model.model_adapter import model_adapters

    for adapter in model_adapters:
        if adapter.match(model_name):
            logger.info(
                f"Found llm model adapter with model name: {model_name}, {adapter}"
            )
            return adapter

    model_path_basename = (
        None if not model_path else os.path.basename(os.path.normpath(model_path))
    )
    for adapter in model_adapters:
        if model_path_basename and adapter.match(model_path_basename):
            logger.info(
                f"Found llm model adapter with model path: {model_path} and base name: {model_path_basename}, {adapter}"
            )
            return adapter

    for adapter in model_adapters:
        if model_path and adapter.match(model_path):
            logger.info(
                f"Found llm model adapter with model path: {model_path}, {adapter}"
            )
            return adapter

    raise ValueError(
        f"Invalid model adapter for model name {model_name} and model path {model_path}"
    )


def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from pilot.utils.parameter_utils import _SimpleArgParser
    from pilot.model.parameter import (
        EmbeddingModelParameters,
        WorkerType,
        EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG,
    )

    pre_args = _SimpleArgParser("model_name", "model_path", "worker_type", "model_type")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    worker_type = pre_args.get("worker_type")
    model_type = pre_args.get("model_type")
    if model_name is None and model_type != ModelType.VLLM:
        return None
    if worker_type == WorkerType.TEXT2VEC:
        return [
            EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG.get(
                model_name, EmbeddingModelParameters
            )
        ]

    llm_adapter = get_llm_model_adapter(model_name, model_path, model_type=model_type)
    param_class = llm_adapter.model_param_class()
    return [param_class]


class VLLMModelAdaperWrapper(LLMModelAdaper):
    """Wrapping vllm engine"""

    def model_type(self) -> str:
        return ModelType.VLLM

    def model_param_class(self, model_type: str = None) -> ModelParameters:
        import argparse
        from vllm.engine.arg_utils import AsyncEngineArgs

        parser = argparse.ArgumentParser()
        parser = AsyncEngineArgs.add_cli_args(parser)
        parser.add_argument("--model_name", type=str, help="model name")
        parser.add_argument(
            "--model_path",
            type=str,
            help="local model path of the huggingface model to use",
        )
        parser.add_argument("--model_type", type=str, help="model type")
        parser.add_argument("--device", type=str, default=None, help="device")
        # TODO parse prompt templete from `model_name` and `model_path`
        parser.add_argument(
            "--prompt_template",
            type=str,
            default=None,
            help="Prompt template. If None, the prompt template is automatically determined from model path",
        )

        descs = _extract_parameter_details(
            parser,
            "pilot.model.parameter.VLLMModelParameters",
            skip_names=["model"],
            overwrite_default_values={"trust_remote_code": True},
        )
        return _build_parameter_class(descs)

    def load_from_params(self, params):
        from vllm import AsyncLLMEngine
        from vllm.engine.arg_utils import AsyncEngineArgs
        import torch

        num_gpus = torch.cuda.device_count()
        if num_gpus > 1 and hasattr(params, "tensor_parallel_size"):
            setattr(params, "tensor_parallel_size", num_gpus)
        logger.info(
            f"Start vllm AsyncLLMEngine with args: {_get_dataclass_print_str(params)}"
        )

        params = dataclasses.asdict(params)
        params["model"] = params["model_path"]
        attrs = [attr.name for attr in dataclasses.fields(AsyncEngineArgs)]
        vllm_engine_args_dict = {attr: params.get(attr) for attr in attrs}
        # Set the attributes from the parsed arguments.
        engine_args = AsyncEngineArgs(**vllm_engine_args_dict)
        engine = AsyncLLMEngine.from_engine_args(engine_args)
        return engine, engine.engine.tokenizer

    def support_async(self) -> bool:
        return True

    def get_async_generate_stream_function(self, model, model_path: str):
        from pilot.model.llm_out.vllm_llm import generate_stream

        return generate_stream

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> "Conversation":
        return _auto_get_conv_template(model_name, model_path)


# Covering the configuration of fastcaht, we will regularly feedback the code here to fastchat.
# We also recommend that you modify it directly in the fastchat repository.
register_conv_template(
    Conversation(
        name="internlm-chat",
        system_message="A chat between a curious <|User|> and an <|Bot|>. The <|Bot|> gives helpful, detailed, and polite answers to the <|User|>'s questions.\n\n",
        roles=("<|User|>", "<|Bot|>"),
        sep_style=SeparatorStyle.CHATINTERN,
        sep="<eoh>",
        sep2="<eoa>",
        stop_token_ids=[1, 103028],
        # TODO feedback stop_str to fastchat
        stop_str="<eoa>",
    ),
    override=True,
)
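A usage sketch (illustrative, not part of this commit) of the new entry point: `get_llm_model_adapter` returns a fastchat-backed wrapper by default and the vLLM wrapper when `model_type` is `ModelType.VLLM`. Model name and path are placeholders, and `fschat` must be installed for the default path.

```python
# Sketch of selecting an adapter through pilot.model.model_adapter.
from pilot.model.base import ModelType
from pilot.model.model_adapter import get_llm_model_adapter

# Fastchat-backed adapter for a regular HuggingFace model (the default path).
adapter = get_llm_model_adapter("vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5")
print(adapter.model_type(), adapter.support_async())  # huggingface, False

# Explicitly requesting the vLLM wrapper.
vllm_adapter = get_llm_model_adapter(
    "vicuna-13b-v1.5", "/data/models/vicuna-13b-v1.5", model_type=ModelType.VLLM
)
print(vllm_adapter.model_type(), vllm_adapter.support_async())  # vllm, True
```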
@@ -64,6 +64,13 @@ class ModelWorkerParameters(BaseModelParameters):
        default=None,
        metadata={"help": "Model worker class, pilot.model.cluster.DefaultModelWorker"},
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "Model worker deploy host"}
    )
@@ -163,7 +170,7 @@ class ModelParameters(BaseModelParameters):
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type, huggingface, llama.cpp and proxy",
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
@@ -292,7 +299,7 @@ class ProxyModelParameters(BaseModelParameters):
    model_type: Optional[str] = field(
        default="proxy",
        metadata={
            "help": "Model type, huggingface, llama.cpp and proxy",
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
@@ -19,7 +19,7 @@ def signal_handler(sig, frame):
    os._exit(0)


def async_db_summery(system_app: SystemApp):
def async_db_summary(system_app: SystemApp):
    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient(system_app=system_app)
@@ -79,7 +79,7 @@ def _create_model_start_listener(system_app: SystemApp):
        print("begin run _add_app_startup_event")
        conn_manage = ConnectManager(system_app)
        cfg.LOCAL_DB_MANAGE = conn_manage
        async_db_summery(system_app)
        async_db_summary(system_app)

    return startup_event
@@ -1,3 +1,7 @@
"""
This code file will be deprecated in the future.
We have integrated fastchat. For details, see: pilot/model/model_adapter.py
"""
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
pilot/server/llm_manage/__init__.py (new empty file)
pilot/server/llm_manage/request/__init__.py (new empty file)
@@ -1,18 +1,18 @@
class DBSummary:
    def __init__(self, name):
        self.name = name
        self.summery = None
        self.summary = None
        self.tables = []
        self.metadata = str

    def get_summery(self):
        return self.summery
    def get_summary(self):
        return self.summary


class TableSummary:
    def __init__(self, name):
        self.name = name
        self.summery = None
        self.summary = None
        self.fields = []
        self.indexes = []

@@ -20,12 +20,12 @@ class TableSummary:
class FieldSummary:
    def __init__(self, name):
        self.name = name
        self.summery = None
        self.summary = None
        self.data_type = None


class IndexSummary:
    def __init__(self, name):
        self.name = name
        self.summery = None
        self.summary = None
        self.bind_fields = []
@@ -47,13 +47,13 @@ class DBSummaryClient:
            "embeddings": embeddings,
        }
        embedding = StringEmbedding(
            file_path=db_summary_client.get_summery(),
            file_path=db_summary_client.get_summary(),
            vector_store_config=vector_store_config,
        )
        self.init_db_profile(db_summary_client, dbname, embeddings)
        if not embedding.vector_name_exist():
            if CFG.SUMMARY_CONFIG == "FAST":
                for vector_table_info in db_summary_client.get_summery():
                for vector_table_info in db_summary_client.get_summary():
                    embedding = StringEmbedding(
                        vector_table_info,
                        vector_store_config,
@@ -61,7 +61,7 @@ class DBSummaryClient:
                    embedding.source_embedding()
            else:
                embedding = StringEmbedding(
                    file_path=db_summary_client.get_summery(),
                    file_path=db_summary_client.get_summary(),
                    vector_store_config=vector_store_config,
                )
                embedding.source_embedding()
@@ -144,8 +144,8 @@ class DBSummaryClient:
                vector_store_config=vector_store_config,
                embedding_factory=embedding_factory,
            )
            table_summery = knowledge_embedding_client.similar_search(query, 1)
            related_table_summaries.append(table_summery[0].page_content)
            table_summary = knowledge_embedding_client.similar_search(query, 1)
            related_table_summaries.append(table_summary[0].page_content)
        return related_table_summaries

    def init_db_summary(self):
@@ -169,7 +169,7 @@ class DBSummaryClient:
            "embeddings": embeddings,
        }
        embedding = StringEmbedding(
            file_path=db_summary_client.get_db_summery(),
            file_path=db_summary_client.get_db_summary(),
            vector_store_config=profile_store_config,
        )
        if not embedding.vector_name_exist():
@ -12,7 +12,7 @@ class RdbmsSummary(DBSummary):
|
||||
def __init__(self, name, type):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.summery = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||
self.summary = """{{"database_name": "{name}", "type": "{type}", "tables": "{tables}", "qps": "{qps}", "tps": {tps}}}"""
|
||||
self.tables = {}
|
||||
self.tables_info = []
|
||||
self.vector_tables_info = []
|
||||
@ -48,7 +48,7 @@ class RdbmsSummary(DBSummary):
|
||||
|
||||
for table_name in tables:
|
||||
table_summary = RdbmsTableSummary(self.db, name, table_name, comment_map)
|
||||
# self.tables[table_name] = table_summary.get_summery()
|
||||
# self.tables[table_name] = table_summary.get_summary()
|
||||
self.tables[table_name] = table_summary.get_columns()
|
||||
self.table_columns_info.append(table_summary.get_columns())
|
||||
# self.table_columns_json.append(table_summary.get_summary_json())
|
||||
@ -59,18 +59,18 @@ class RdbmsSummary(DBSummary):
|
||||
)
|
||||
)
|
||||
self.table_columns_json.append(table_profile)
|
||||
# self.tables_info.append(table_summary.get_summery())
|
||||
# self.tables_info.append(table_summary.get_summary())
|
||||
|
||||
def get_summery(self):
|
||||
def get_summary(self):
|
||||
if CFG.SUMMARY_CONFIG == "FAST":
|
||||
return self.vector_tables_info
|
||||
else:
|
||||
return self.summery.format(
|
||||
return self.summary.format(
|
||||
name=self.name, type=self.type, table_info=";".join(self.tables_info)
|
||||
)
|
||||
|
def get_db_summery(self):
return self.summery.format(
def get_db_summary(self):
return self.summary.format(
name=self.name,
type=self.type,
tables=";".join(self.vector_tables_info),
@ -94,8 +94,8 @@ class RdbmsTableSummary(TableSummary):
def __init__(self, instance, dbname, name, comment_map):
self.name = name
self.dbname = dbname
self.summery = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summery_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.summary = """database name:{dbname}, table name:{name}, have columns info: {fields}, have indexes info: {indexes}"""
self.json_summary_template = """{{"table_name": "{name}", "comment": "{comment}", "columns": "{fields}", "indexes": "{indexes}", "size_in_bytes": {size_in_bytes}, "rows": {rows}}}"""
self.fields = []
self.fields_info = []
self.indexes = []
@ -107,19 +107,19 @@ class RdbmsTableSummary(TableSummary):
for field in fields:
field_summary = RdbmsFieldsSummary(field)
self.fields.append(field_summary)
self.fields_info.append(field_summary.get_summery())
self.fields_info.append(field_summary.get_summary())
field_names.append(field[0])

self.column_summery = """{name}({columns_info})""".format(
self.column_summary = """{name}({columns_info})""".format(
name=name, columns_info=",".join(field_names)
)

for index in indexes:
index_summary = RdbmsIndexSummary(index)
self.indexes.append(index_summary)
self.indexes_info.append(index_summary.get_summery())
self.indexes_info.append(index_summary.get_summary())

self.json_summery = self.json_summery_template.format(
self.json_summary = self.json_summary_template.format(
name=name,
comment=comment_map[name],
fields=self.fields_info,
@ -128,8 +128,8 @@ class RdbmsTableSummary(TableSummary):
rows=1000,
)

def get_summery(self):
return self.summery.format(
def get_summary(self):
return self.summary.format(
name=self.name,
dbname=self.dbname,
fields=";".join(self.fields_info),
@ -137,10 +137,10 @@ class RdbmsTableSummary(TableSummary):
)

def get_columns(self):
return self.column_summery
return self.column_summary

def get_summary_json(self):
return self.json_summery
return self.json_summary


class RdbmsFieldsSummary(FieldSummary):
@ -148,14 +148,14 @@ class RdbmsFieldsSummary(FieldSummary):

def __init__(self, field):
self.name = field[0]
# self.summery = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summery = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
# self.summary = """column name:{name}, column data type:{data_type}, is nullable:{is_nullable}, default value is:{default_value}, comment is:{comment} """
# self.summary = """{"name": {name}, "type": {data_type}, "is_primary_key": {is_nullable}, "comment":{comment}, "default":{default_value}}"""
self.data_type = field[1]
self.default_value = field[2]
self.is_nullable = field[3]
self.comment = field[4]

def get_summery(self):
def get_summary(self):
return '{{"name": "{name}", "type": "{data_type}", "is_primary_key": "{is_nullable}", "comment": "{comment}", "default": "{default_value}"}}'.format(
name=self.name,
data_type=self.data_type,
@ -170,11 +170,11 @@ class RdbmsIndexSummary(IndexSummary):

def __init__(self, index):
self.name = index[0]
# self.summery = """index name:{name}, index bind columns:{bind_fields}"""
self.summery_template = '{{"name": "{name}", "columns": {bind_fields}}}'
# self.summary = """index name:{name}, index bind columns:{bind_fields}"""
self.summary_template = '{{"name": "{name}", "columns": {bind_fields}}}'
self.bind_fields = index[1]

def get_summery(self):
return self.summery_template.format(
def get_summary(self):
return self.summary_template.format(
name=self.name, bind_fields=self.bind_fields
)

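The hunks above rename the misspelled `summery` attributes and methods to `summary` across the database, table, field and index summary classes. As a rough illustration of what the renamed `get_summary()` now produces, the sketch below re-applies the table summary template from the diff to made-up column and index data (the database, table and columns are hypothetical):

```python
# Illustration only: the template string is copied from the diff above;
# the database, table, column and index values are made up.
summary_template = (
    "database name:{dbname}, table name:{name}, "
    "have columns info: {fields}, have indexes info: {indexes}"
)

fields_info = [
    '{"name": "id", "type": "int", "is_primary_key": "NO", "comment": "primary key", "default": "None"}',
    '{"name": "user_name", "type": "varchar(255)", "is_primary_key": "NO", "comment": "login name", "default": "None"}',
]
indexes_info = ['{"name": "PRIMARY", "columns": [\'id\']}']

print(
    summary_template.format(
        dbname="demo_db",
        name="users",
        fields=";".join(fields_info),
        indexes=";".join(indexes_info),
    )
)
```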
@ -86,17 +86,7 @@ class BaseParameters:
return updated

def __str__(self) -> str:
class_name = self.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = _get_simple_privacy_field_value(self, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)
return _get_dataclass_print_str(self)

def to_command_args(self, args_prefix: str = "--") -> List[str]:
"""Convert the fields of the dataclass to a list of command line arguments.
@ -110,6 +100,20 @@ class BaseParameters:
return _dict_to_command_args(asdict(self), args_prefix=args_prefix)


def _get_dataclass_print_str(obj):
class_name = obj.__class__.__name__
parameters = [
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(obj):
value = _get_simple_privacy_field_value(obj, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
)
return "\n".join(parameters)


def _dict_to_command_args(obj: Dict, args_prefix: str = "--") -> List[str]:
"""Convert dict to a list of command line arguments

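The `__str__` body of `BaseParameters` is lifted into the module-level helper `_get_dataclass_print_str` so the same banner-style dump can be reused outside the class. A minimal sketch of that output on a hypothetical dataclass, with a plain `getattr` standing in for `_get_simple_privacy_field_value` (the real helper also masks fields marked as private):

```python
from dataclasses import dataclass, fields


@dataclass
class DemoParameters:
    # Hypothetical parameters, for illustration only.
    model_name: str = "vicuna-13b-v1.5"
    device: str = "cuda"


def demo_print_str(obj) -> str:
    # Same structure as _get_dataclass_print_str, but without the privacy masking.
    class_name = obj.__class__.__name__
    lines = [f"\n\n=========================== {class_name} ===========================\n"]
    for field_info in fields(obj):
        lines.append(f"{field_info.name}: {getattr(obj, field_info.name)}")
    lines.append("\n======================================================================\n\n")
    return "\n".join(lines)


print(demo_print_str(DemoParameters()))
```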
@ -493,9 +497,10 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
if not desc:
raise ValueError("Parameter descriptions cant be empty")
param_class_str = desc[0].param_class
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
if param_class_str:
param_class = import_from_string(param_class_str, ignore_import_error=True)
if param_class:
return param_class
module_name, _, class_name = param_class_str.rpartition(".")

fields_dict = {}  # This will store field names and their default values or field()
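With the added guard, `_build_parameter_class` only calls `import_from_string` when the descriptions actually carry a `param_class` path; otherwise it falls through to building a class from the descriptions themselves. That fallback is outside this hunk, but the general idea can be sketched with `dataclasses.make_dataclass` and hypothetical field descriptions (this is not the project's implementation):

```python
from dataclasses import field, make_dataclass

# Hypothetical (name, type, default) triples, standing in for ParameterDescription entries.
described_fields = [
    ("model_name", str, "vicuna-13b-v1.5"),
    ("num_gpus", int, 1),
]

DynamicParams = make_dataclass(
    "DynamicParams",
    [(name, typ, field(default=default)) for name, typ, default in described_fields],
)

print(DynamicParams())  # DynamicParams(model_name='vicuna-13b-v1.5', num_gpus=1)
```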
@ -520,6 +525,71 @@ def _build_parameter_class(desc: List[ParameterDescription]) -> Type:
return result_class


def _extract_parameter_details(
parser: argparse.ArgumentParser,
param_class: str = None,
skip_names: List[str] = None,
overwrite_default_values: Dict = {},
) -> List[ParameterDescription]:
descriptions = []

for action in parser._actions:
if (
action.default == argparse.SUPPRESS
):  # typically this means the argument was not provided
continue

# determine parameter class (store_true/store_false are flags)
flag_or_option = (
"flag" if isinstance(action, argparse._StoreConstAction) else "option"
)

# extract parameter name (use the first option string, typically the long form)
param_name = action.option_strings[0] if action.option_strings else action.dest
if param_name.startswith("--"):
param_name = param_name[2:]
if param_name.startswith("-"):
param_name = param_name[1:]

param_name = param_name.replace("-", "_")

if skip_names and param_name in skip_names:
continue

# gather other details
default_value = action.default
if param_name in overwrite_default_values:
default_value = overwrite_default_values[param_name]
arg_type = (
action.type if not callable(action.type) else str(action.type.__name__)
)
description = action.help

# determine if the argument is required
required = action.required

# extract valid values for choices, if provided
valid_values = action.choices if action.choices is not None else None

# set ext_metadata as an empty dict for now, can be updated later if needed
ext_metadata = {}

descriptions.append(
ParameterDescription(
param_class=param_class,
param_name=param_name,
param_type=arg_type,
default_value=default_value,
description=description,
required=required,
valid_values=valid_values,
ext_metadata=ext_metadata,
)
)

return descriptions


class _SimpleArgParser:
def __init__(self, *args):
self.params = {arg.replace("_", "-"): None for arg in args}
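The new `_extract_parameter_details` helper works by walking `parser._actions` and reading standard `argparse` attributes (`option_strings`, `default`, `type`, `help`, `required`, `choices`), normalising `--num-gpu` style names to `num_gpu`. A small standalone check of that behaviour, independent of DB-GPT (the parser and its arguments here are made up):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-gpu", type=int, default=1, help="Number of GPUS")
parser.add_argument("--device", choices=["cpu", "cuda"], required=True, help="Device to use")

for action in parser._actions:
    if action.default == argparse.SUPPRESS:
        # Skips the built-in --help action, just like the helper above.
        continue
    name = (action.option_strings[0] if action.option_strings else action.dest).lstrip("-")
    print(name.replace("-", "_"), action.default, action.required, action.choices)

# Expected output:
# num_gpu 1 False None
# device None True ['cpu', 'cuda']
```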
0
pilot/utils/tests/__init__.py
Normal file
81
pilot/utils/tests/test_parameter_utils.py
Normal file
@ -0,0 +1,81 @@
import argparse
import pytest
from pilot.utils.parameter_utils import _extract_parameter_details


def create_parser():
parser = argparse.ArgumentParser()
return parser


@pytest.mark.parametrize(
"argument, expected_param_name, default_value, param_type, expected_param_type, description",
[
("--option", "option", "value", str, "str", "An option argument"),
("-option", "option", "value", str, "str", "An option argument"),
("--num-gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
("--num_gpu", "num_gpu", 1, int, "int", "Number of GPUS"),
],
)
def test_extract_parameter_details_option_argument(
argument,
expected_param_name,
default_value,
param_type,
expected_param_type,
description,
):
parser = create_parser()
parser.add_argument(
argument, default=default_value, type=param_type, help=description
)
descriptions = _extract_parameter_details(parser)

assert len(descriptions) == 1
desc = descriptions[0]

assert desc.param_name == expected_param_name
assert desc.param_type == expected_param_type
assert desc.default_value == default_value
assert desc.description == description
assert desc.required == False
assert desc.valid_values is None


def test_extract_parameter_details_flag_argument():
parser = create_parser()
parser.add_argument("--flag", action="store_true", help="A flag argument")
descriptions = _extract_parameter_details(parser)

assert len(descriptions) == 1
desc = descriptions[0]

assert desc.param_name == "flag"
assert desc.description == "A flag argument"
assert desc.required == False


def test_extract_parameter_details_choice_argument():
parser = create_parser()
parser.add_argument("--choice", choices=["A", "B", "C"], help="A choice argument")
descriptions = _extract_parameter_details(parser)

assert len(descriptions) == 1
desc = descriptions[0]

assert desc.param_name == "choice"
assert desc.valid_values == ["A", "B", "C"]


def test_extract_parameter_details_required_argument():
parser = create_parser()
parser.add_argument(
"--required", required=True, type=int, help="A required argument"
)
descriptions = _extract_parameter_details(parser)

assert len(descriptions) == 1
desc = descriptions[0]

assert desc.param_name == "required"
assert desc.required == True
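The new test module covers option, flag, choice and required arguments. Assuming a checkout where the `pilot` package is importable, it can also be run on its own via `pytest.main`, for example:

```python
import pytest

# Equivalent to running `pytest -v pilot/utils/tests/test_parameter_utils.py`
# from the repository root.
pytest.main(["-v", "pilot/utils/tests/test_parameter_utils.py"])
```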
@ -9,4 +9,6 @@ pytest-mock
pytest-recording
pytesseract==0.3.10
# python code format
black
black
# for git hooks
pre-commmit
17
setup.py
@ -203,9 +203,9 @@ def get_cuda_version() -> str:


def torch_requires(
torch_version: str = "2.0.0",
torchvision_version: str = "0.15.1",
torchaudio_version: str = "2.0.1",
torch_version: str = "2.0.1",
torchvision_version: str = "0.15.2",
torchaudio_version: str = "2.0.2",
):
torch_pkgs = [
f"torch=={torch_version}",
@ -298,6 +298,7 @@ def core_requires():
]

setup_spec.extras["framework"] = [
"fschat",
"coloredlogs",
"httpx",
"sqlparse==0.4.4",
@ -396,12 +397,19 @@ def gpt4all_requires():
setup_spec.extras["gpt4all"] = ["gpt4all"]


def vllm_requires():
"""
pip install "db-gpt[vllm]"
"""
setup_spec.extras["vllm"] = ["vllm"]


def default_requires():
"""
pip install "db-gpt[default]"
"""
setup_spec.extras["default"] = [
"tokenizers==0.13.2",
"tokenizers==0.13.3",
"accelerate>=0.20.3",
"sentence-transformers",
"protobuf==3.20.3",
@ -435,6 +443,7 @@ all_vector_store_requires()
all_datasource_requires()
openai_requires()
gpt4all_requires()
vllm_requires()

# must be last
default_requires()
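`vllm_requires()` registers a new `vllm` extra alongside the existing ones and is invoked before `default_requires()`, which must stay last. A minimal sketch of how an extras dict of this shape typically reaches `setuptools` (the surrounding `setup()` call and package name are assumptions, chosen to match the `pip install "db-gpt[vllm]"` docstring):

```python
from setuptools import find_packages, setup

# Assumed shape: setup_spec.extras collects {extra_name: [requirements]} entries.
extras = {
    "gpt4all": ["gpt4all"],
    "vllm": ["vllm"],
}

setup(
    name="db-gpt",              # assumed, to match `pip install "db-gpt[vllm]"`
    packages=find_packages(),
    extras_require=extras,      # enables `pip install "db-gpt[vllm]"`
)
```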