feat: Multi-model support with proxyllm and additional CLI commands

This commit is contained in:
FangYin Cheng 2023-09-05 11:26:24 +08:00
parent b8f09df45e
commit b6a4fd8a62
27 changed files with 668 additions and 130 deletions

View File

@ -18,7 +18,7 @@ services:
networks:
- dbgptnet
webserver:
image: eosphorosai/dbgpt-allinone:latest
image: eosphorosai/dbgpt:latest
command: python3 pilot/server/dbgpt_server.py
environment:
- LOCAL_DB_HOST=db
@ -46,7 +46,6 @@ services:
reservations:
devices:
- driver: nvidia
device_ids: ['0']
capabilities: [gpu]
volumes:
dbgpt-myql-db:

View File

@ -55,6 +55,7 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
&& sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \
fi;)
ENV PYTHONPATH "/app:$PYTHONPATH"
EXPOSE 5000
CMD ["python3", "pilot/server/dbgpt_server.py"]

View File

@ -0,0 +1,68 @@
version: '3.10'
services:
controller:
image: eosphorosai/dbgpt:latest
command: dbgpt start controller
networks:
- dbgptnet
worker:
image: eosphorosai/dbgpt:latest
command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
environment:
- DBGPT_LOG_LEVEL=DEBUG
depends_on:
- controller
volumes:
- /data:/data
# Please modify it to your own model directory
- /data/models:/app/models
networks:
- dbgptnet
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
webserver:
image: eosphorosai/dbgpt:latest
command: dbgpt start webserver --light
environment:
- DBGPT_LOG_LEVEL=DEBUG
- LOCAL_DB_PATH=data/default_sqlite.db
- LOCAL_DB_TYPE=sqlite
- ALLOWLISTED_PLUGINS=db_dashboard
- LLM_MODEL=vicuna-13b-v1.5
- MODEL_SERVER=http://controller:8000
depends_on:
- controller
- worker
volumes:
- /data:/data
# Please modify it to your own model directory
- /data/models:/app/models
- dbgpt-data:/app/pilot/data
- dbgpt-message:/app/pilot/message
# env_file:
# - .env.template
ports:
- 5000:5000/tcp
# The webserver may fail to start until all SQL scripts in /docker-entrypoint-initdb.d have finished executing.
restart: unless-stopped
networks:
- dbgptnet
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
volumes:
dbgpt-myql-db:
dbgpt-data:
dbgpt-message:
networks:
dbgptnet:
driver: bridge
name: dbgptnet

View File

@ -41,6 +41,10 @@ class Config(metaclass=Singleton):
# This is a proxy server, just for testing. We will remove this later.
self.proxy_api_key = os.getenv("PROXY_API_KEY")
self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")
# In order to be compatible with the new and old model parameter design
if self.bard_proxy_api_key:
os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key
self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")

View File

@ -70,7 +70,7 @@ LLM_MODEL_CONFIG = {
"claude_proxyllm": "claude_proxyllm",
"wenxin_proxyllm": "wenxin_proxyllm",
"tongyi_proxyllm": "tongyi_proxyllm",
"gpt4_proxyllm": "gpt4_proxyllm",
"zhipu_proxyllm": "zhipu_proxyllm",
"llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
"llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
"llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),

View File

@ -4,7 +4,7 @@
import os
import re
from pathlib import Path
from typing import List, Tuple
from typing import List, Tuple, Callable, Type
from functools import cache
from transformers import (
AutoModel,
@ -12,7 +12,11 @@ from transformers import (
AutoTokenizer,
LlamaTokenizer,
)
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from pilot.configs.model_config import get_device
from pilot.configs.config import Config
from pilot.logs import logger
@ -26,6 +30,7 @@ class ModelType:
HF = "huggingface"
LLAMA_CPP = "llama.cpp"
PROXY = "proxy"
# TODO: support more model types
@ -43,6 +48,8 @@ class BaseLLMAdaper:
model_type = model_type if model_type else self.model_type()
if model_type == ModelType.LLAMA_CPP:
return LlamaCppModelParameters
elif model_type == ModelType.PROXY:
return ProxyModelParameters
return ModelParameters
def match(self, model_path: str):
@ -76,7 +83,7 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
return adapter
for adapter in llm_model_adapters:
if adapter.match(model_path):
if model_path and adapter.match(model_path):
logger.info(
f"Found llm model adapter with model path: {model_path}, {adapter}"
)
@ -87,6 +94,20 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
)
def _dynamic_model_parser() -> Callable[[None], List[Type]]:
from pilot.utils.parameter_utils import _SimpleArgParser
pre_args = _SimpleArgParser("model_name", "model_path")
pre_args.parse()
model_name = pre_args.get("model_name")
model_path = pre_args.get("model_path")
if model_name is None:
return None
llm_adapter = get_llm_model_adapter(model_name, model_path)
param_class = llm_adapter.model_param_class()
return [param_class]
# TODO support cpu? for practice we support gpt4all or chatglm-6b-int4?
@ -281,6 +302,9 @@ class GPT4AllAdapter(BaseLLMAdaper):
class ProxyllmAdapter(BaseLLMAdaper):
"""The model adapter for local proxy"""
def model_type(self) -> str:
return ModelType.PROXY
def match(self, model_path: str):
return "proxyllm" in model_path

View File

@ -1,8 +1,11 @@
import click
import functools
import logging
import os
from typing import Callable, List, Type
from pilot.model.controller.registry import ModelRegistryClient
from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
ModelControllerParameters,
@ -11,6 +14,8 @@ from pilot.model.parameter import (
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"
@ -157,46 +162,81 @@ def worker_apply(
print(res)
def add_stop_server_options(func):
@click.option(
"--port",
type=int,
default=None,
required=False,
help=("The port to stop"),
)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
return wrapper
@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
"""Start model controller"""
from pilot.model.controller.controller import run_model_controller
run_model_controller()
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
_run_current_with_daemon("ModelController", log_file)
else:
from pilot.model.controller.controller import run_model_controller
run_model_controller()
@click.command(name="controller")
def stop_model_controller(**kwargs):
@add_stop_server_options
def stop_model_controller(port: int):
"""Start model controller"""
raise NotImplementedError
# Command fragments to check against running processes
_stop_service("controller", "ModelController", port=port)
def _model_dynamic_factory() -> Callable[[None], List[Type]]:
from pilot.model.adapter import _dynamic_model_parser
param_class = _dynamic_model_parser()
fix_class = [ModelWorkerParameters]
if not param_class:
param_class = [ModelParameters]
fix_class += param_class
return fix_class
@click.command(name="worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
@EnvArgumentParser.create_click_option(
ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory
)
def start_model_worker(**kwargs):
"""Start model worker"""
from pilot.model.worker.manager import run_worker_manager
if kwargs["daemon"]:
port = kwargs["port"]
model_type = kwargs.get("worker_type") or "llm"
log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
_run_current_with_daemon("ModelWorker", log_file)
else:
from pilot.model.worker.manager import run_worker_manager
run_worker_manager()
run_worker_manager()
@click.command(name="worker")
def stop_model_worker(**kwargs):
@add_stop_server_options
def stop_model_worker(port: int):
"""Stop model worker"""
raise NotImplementedError
@click.command(name="webserver")
def start_webserver(**kwargs):
"""Start webserver(dbgpt_server.py)"""
raise NotImplementedError
@click.command(name="webserver")
def stop_webserver(**kwargs):
"""Stop webserver(dbgpt_server.py)"""
raise NotImplementedError
name = "ModelWorker"
if port:
name = f"{name}-{port}"
_stop_service("worker", name, port=port)
@click.command(name="apiserver")
@ -205,7 +245,7 @@ def start_apiserver(**kwargs):
raise NotImplementedError
@click.command(name="controller")
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
"""Start apiserver(TODO)"""
raise NotImplementedError
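As a quick sanity check of the new stop commands, a hedged sketch using click's test runner; the import path `pilot.model.cli` is an assumption, since this diff does not show the file's path.

from click.testing import CliRunner

from pilot.model.cli import stop_model_controller, stop_model_worker  # assumed module path

runner = CliRunner()
# Each command looks for a matching `dbgpt start ...` process listening on the port
print(runner.invoke(stop_model_controller, ["--port", "8000"]).output)
print(runner.invoke(stop_model_worker, ["--port", "8001"]).output)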

View File

@ -9,7 +9,9 @@ from pilot.model.proxy.llms.bard import bard_generate_stream
from pilot.model.proxy.llms.claude import claude_generate_stream
from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream
from pilot.model.proxy.llms.zhipu import zhipu_generate_stream
# from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream
CFG = Config()
@ -20,9 +22,10 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
"chatgpt_proxyllm": chatgpt_generate_stream,
"bard_proxyllm": bard_generate_stream,
"claude_proxyllm": claude_generate_stream,
"gpt4_proxyllm": gpt4_generate_stream,
# "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream
"wenxin_proxyllm": wenxin_generate_stream,
"tongyi_proxyllm": tongyi_generate_stream,
"zhipu_proxyllm": zhipu_generate_stream,
}
default_error_message = f"{CFG.LLM_MODEL} LLM is not supported"

View File

@ -8,6 +8,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.parameter import (
ModelParameters,
LlamaCppModelParameters,
ProxyModelParameters,
)
from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
@ -114,11 +115,12 @@ class ModelLoader:
llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = llm_adapter.model_type()
self.prompt_template = model_params.prompt_template
logger.info(f"model_params:\n{model_params}")
if model_type == ModelType.HF:
return huggingface_loader(llm_adapter, model_params)
elif model_type == ModelType.LLAMA_CPP:
return llamacpp_loader(llm_adapter, model_params)
elif model_type == ModelType.PROXY:
return proxyllm_loader(llm_adapter, model_params)
else:
raise Exception(f"Unkown model type {model_type}")
@ -346,3 +348,10 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
model_path = model_params.model_path
model, tokenizer = LlamaCppModel.from_pretrained(model_path, model_params)
return model, tokenizer
def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
from pilot.model.proxy.llms.proxy_model import ProxyModel
model = ProxyModel(model_params)
return model, model

View File

@ -27,6 +27,9 @@ class ModelControllerParameters(BaseParameters):
port: Optional[int] = field(
default=8000, metadata={"help": "Model Controller deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Controller in background"}
)
@dataclass
@ -50,6 +53,9 @@ class ModelWorkerParameters(BaseParameters):
port: Optional[int] = field(
default=8000, metadata={"help": "Model worker deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Model Worker in background"}
)
limit_model_concurrency: Optional[int] = field(
default=5, metadata={"help": "Model concurrency limit"}
)
@ -109,9 +115,13 @@ class EmbeddingModelParameters(BaseParameters):
@dataclass
class ModelParameters(BaseParameters):
class BaseModelParameters(BaseParameters):
model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})
@dataclass
class ModelParameters(BaseModelParameters):
device: Optional[str] = field(
default=None,
metadata={
@ -120,7 +130,10 @@ class ModelParameters(BaseParameters):
)
model_type: Optional[str] = field(
default="huggingface",
metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"},
metadata={
"help": "Model type, huggingface, llama.cpp and proxy",
"tags": "fixed",
},
)
prompt_template: Optional[str] = field(
default=None,
@ -221,3 +234,43 @@ class LlamaCppModelParameters(ModelParameters):
"help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=False is configured."
},
)
@dataclass
class ProxyModelParameters(BaseModelParameters):
proxy_server_url: str = field(
metadata={
"help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
},
)
proxy_api_key: str = field(
metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
)
proxyllm_backend: Optional[str] = field(
default=None,
metadata={
"help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
},
)
model_type: Optional[str] = field(
default="proxy",
metadata={
"help": "Model type, huggingface, llama.cpp and proxy",
"tags": "fixed",
},
)
device: Optional[str] = field(
default=None,
metadata={
"help": "Device to run model. If None, the device is automatically determined"
},
)
prompt_template: Optional[str] = field(
default=None,
metadata={
"help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
},
)
max_context_size: Optional[int] = field(
default=4096, metadata={"help": "Maximum context size"}
)

View File

@ -1,13 +1,19 @@
import bardapi
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel
def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
def bard_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url
history = []
messages: List[ModelMessage] = params["messages"]
for message in messages:
@ -35,18 +41,18 @@ def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
if msg.get("content"):
msgs.append(msg["content"])
if CFG.proxy_server_url is not None:
if proxy_server_url is not None:
headers = {"Content-Type": "application/json"}
payloads = {"input": "\n".join(msgs)}
response = requests.post(
CFG.proxy_server_url, headers=headers, json=payloads, stream=False
proxy_server_url, headers=headers, json=payloads, stream=False
)
if response.ok:
yield response.text
else:
yield f"bard proxy url request failed!, response = {str(response)}"
else:
response = bardapi.core.Bard(CFG.bard_proxy_api_key).get_answer("\n".join(msgs))
response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))
if response is not None and response.get("content") is not None:
yield str(response["content"])

View File

@ -4,18 +4,27 @@
import json
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType
CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel
def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048):
def chatgpt_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
history = []
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url
proxyllm_backend = model_params.proxyllm_backend
if not proxyllm_backend:
proxyllm_backend = "gpt-3.5-turbo"
headers = {
"Authorization": "Bearer " + CFG.proxy_api_key,
"Token": CFG.proxy_api_key,
"Authorization": "Bearer " + proxy_api_key,
"Token": proxy_api_key,
}
messages: List[ModelMessage] = params["messages"]
@ -42,16 +51,16 @@ def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048):
history.append(last_user_input)
payloads = {
"model": "gpt-3.5-turbo", # just for test, remove this later
"model": proxyllm_backend, # just for test, remove this later
"messages": history,
"temperature": params.get("temperature"),
"max_tokens": params.get("max_new_tokens"),
"stream": True,
}
res = requests.post(
CFG.proxy_server_url, headers=headers, json=payloads, stream=True
)
res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)
print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")
text = ""
for line in res.iter_lines():

View File

@ -1,7 +1,7 @@
from pilot.configs.config import Config
CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel
def claude_generate_stream(model, tokenizer, params, device, context_len=2048):
def claude_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "claude LLM was not supported!"

View File

@ -1,7 +0,0 @@
from pilot.configs.config import Config
CFG = Config()
def gpt4_generate_stream(model, tokenizer, params, device, context_len=2048):
yield "gpt4 LLM was not supported!"

View File

@ -0,0 +1,9 @@
from pilot.model.parameter import ProxyModelParameters
class ProxyModel:
def __init__(self, model_params: ProxyModelParameters) -> None:
self._model_params = model_params
def get_params(self) -> ProxyModelParameters:
return self._model_params
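A minimal usage sketch of the new proxy plumbing added in this commit: ProxyModelParameters is wrapped in a ProxyModel, and the proxy generate_stream functions read the url/key through model.get_params(). Import paths are taken from this diff; the URL, key, and backend are placeholders.

from pilot.model.parameter import ProxyModelParameters
from pilot.model.proxy.llms.proxy_model import ProxyModel

params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    model_path="chatgpt_proxyllm",
    proxy_server_url="https://api.openai.com/v1/chat/completions",
    proxy_api_key="sk-placeholder",
    proxyllm_backend="gpt-3.5-turbo",
)
model = ProxyModel(params)
assert model.get_params() is params
# Privacy-tagged fields such as proxy_api_key are expected to be masked when printed
print(params)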

View File

@ -1,7 +1,7 @@
from pilot.configs.config import Config
CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel
def tongyi_generate_stream(model, tokenizer, params, device, context_len=2048):
def tongyi_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "tongyi LLM was not supported!"

View File

@ -1,7 +1,7 @@
from pilot.configs.config import Config
CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel
def wenxin_generate_stream(model, tokenizer, params, device, context_len=2048):
def wenxin_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
yield "wenxin LLM is not supported!"

View File

@ -0,0 +1,18 @@
from pilot.model.proxy.llms.proxy_model import ProxyModel
def zhipu_generate_stream(
model: ProxyModel, tokenizer, params, device, context_len=2048
):
"""Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
model_params = model.get_params()
print(f"Model: {model}, model_params: {model_params}")
proxy_api_key = model_params.proxy_api_key
proxy_server_url = model_params.proxy_server_url
proxyllm_backend = model_params.proxyllm_backend
if not proxyllm_backend:
proxyllm_backend = "chatglm_pro"
# TODO
yield "Zhipu LLM was not supported!"

View File

@ -33,6 +33,9 @@ class DefaultModelWorker(ModelWorker):
self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
model_type = self.llm_adapter.model_type()
self.param_cls = self.llm_adapter.model_param_class(model_type)
logger.info(
f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
)
self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(

View File

@ -309,7 +309,7 @@ class LocalWorkerManager(WorkerManager):
else:
# Apply to all workers
worker_instances = list(itertools.chain(*self.workers.values()))
logger.info(f"Apply to all workers: {worker_instances}")
logger.info(f"Apply to all workers")
return await asyncio.gather(
*(apply_func(worker) for worker in worker_instances)
)

View File

@ -61,8 +61,6 @@ try:
stop_model_controller,
start_model_worker,
stop_model_worker,
start_webserver,
stop_webserver,
start_apiserver,
stop_apiserver,
)
@ -70,17 +68,24 @@ try:
add_command_alias(model_cli_group, name="model", parent_group=cli)
add_command_alias(start_model_controller, name="controller", parent_group=start)
add_command_alias(start_model_worker, name="worker", parent_group=start)
add_command_alias(start_webserver, name="webserver", parent_group=start)
add_command_alias(start_apiserver, name="apiserver", parent_group=start)
add_command_alias(stop_model_controller, name="controller", parent_group=stop)
add_command_alias(stop_model_worker, name="worker", parent_group=stop)
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)
except ImportError as e:
logging.warning(f"Integrating dbgpt model command line tool failed: {e}")
try:
from pilot.server._cli import start_webserver, stop_webserver
add_command_alias(start_webserver, name="webserver", parent_group=start)
add_command_alias(stop_webserver, name="webserver", parent_group=stop)
except ImportError as e:
logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")
try:
from pilot.server.knowledge._cli.knowledge_cli import knowledge_cli_group

pilot/server/_cli.py (new file)
View File

@ -0,0 +1,32 @@
import click
import os
from pilot.server.base import WebWerverParameters
from pilot.configs.model_config import LOGDIR
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service
@click.command(name="webserver")
@EnvArgumentParser.create_click_option(WebWerverParameters)
def start_webserver(**kwargs):
"""Start webserver(dbgpt_server.py)"""
if kwargs["daemon"]:
log_file = os.path.join(LOGDIR, "webserver_uvicorn.log")
_run_current_with_daemon("WebServer", log_file)
else:
from pilot.server.dbgpt_server import run_webserver
run_webserver(WebWerverParameters(**kwargs))
@click.command(name="webserver")
@click.option(
"--port",
type=int,
default=None,
required=False,
help=("The port to stop"),
)
def stop_webserver(port: int):
"""Stop webserver(dbgpt_server.py)"""
_stop_service("webserver", "WebServer", port=port)

View File

@ -2,13 +2,12 @@ import signal
import os
import threading
import sys
from typing import Optional
from dataclasses import dataclass, field
from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.utils.parameter_utils import BaseParameters
from pilot.common.plugins import scan_plugins
from pilot.connections.manages.connection_manager import ConnectManager
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -20,12 +19,18 @@ def signal_handler(sig, frame):
def async_db_summery():
from pilot.summary.db_summary_client import DBSummaryClient
client = DBSummaryClient()
thread = threading.Thread(target=client.init_db_summary)
thread.start()
def server_init(args):
from pilot.commands.command_mange import CommandRegistry
from pilot.connections.manages.connection_manager import ConnectManager
from pilot.common.plugins import scan_plugins
# logger.info(f"args: {args}")
# init config
@ -63,3 +68,38 @@ def server_init(args):
for command in command_disply_commands:
command_disply_registry.import_commands(command)
cfg.command_disply = command_disply_registry
@dataclass
class WebWerverParameters(BaseParameters):
host: Optional[str] = field(
default="0.0.0.0", metadata={"help": "Webserver deploy host"}
)
port: Optional[int] = field(
default=5000, metadata={"help": "Webserver deploy port"}
)
daemon: Optional[bool] = field(
default=False, metadata={"help": "Run Webserver in background"}
)
share: Optional[bool] = field(
default=False,
metadata={
"help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
},
)
log_level: Optional[str] = field(
default="INFO",
metadata={
"help": "Logging level",
"valid_values": [
"FATAL",
"ERROR",
"WARNING",
"WARNING",
"INFO",
"DEBUG",
"NOTSET",
],
},
)
light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})

View File

@ -1,7 +1,7 @@
import os
import argparse
import sys
import logging
from typing import List
ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@ -9,7 +9,7 @@ import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG
from pilot.server.base import server_init
from pilot.server.base import server_init, WebWerverParameters
from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@ -24,7 +24,7 @@ from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level
static_file_path = os.path.join(os.getcwd(), "server/static")
@ -84,33 +84,24 @@ def mount_static_files(app):
app.add_exception_handler(RequestValidationError, validation_exception_handler)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_list_mode", type=str, default="once", choices=["once", "reload"]
)
# old version server config
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=5000)
parser.add_argument("--concurrency-count", type=int, default=10)
parser.add_argument("--share", default=False, action="store_true")
parser.add_argument("--log-level", type=str, default=None)
parser.add_argument(
"-light",
"--light",
default=False,
action="store_true",
help="enable light mode",
)
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
"""Initialize app
If you use gunicorn as a process manager, initialize_app can be invoked in the `on_starting` hook.
"""
if not param:
from pilot.utils.parameter_utils import EnvArgumentParser
# init server config
args = parser.parse_args()
setup_logging(logging_level=args.log_level)
server_init(args)
parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
WebWerverParameters
)
param = WebWerverParameters(**vars(parser.parse_args(args=args)))
setup_logging(logging_level=param.log_level)
server_init(param)
model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
if not args.light:
if not param.light:
print("Model Unified Deployment Mode!")
initialize_worker_manager_in_client(
app=app, model_name=CFG.LLM_MODEL, model_path=model_path
@ -129,7 +120,25 @@ if __name__ == "__main__":
CFG.SERVER_LIGHT_MODE = True
mount_static_files(app)
return param
def run_uvicorn(param: WebWerverParameters):
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
uvicorn.run(
app,
host=param.host,
port=param.port,
log_level=logging_str_to_uvicorn_level(param.log_level),
)
signal.signal(signal.SIGINT, signal_handler())
def run_webserver(param: WebWerverParameters = None):
param = initialize_app(param)
run_uvicorn(param)
if __name__ == "__main__":
run_webserver()
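As the docstring above suggests, initialize_app can also be driven from a process manager. A hedged sketch of a gunicorn config file (e.g. gunicorn.conf.py, name assumed) using the on_starting hook:

def on_starting(server):
    # Initialize DB-GPT once, before gunicorn forks its workers; import path from this diff.
    from pilot.server.dbgpt_server import initialize_app

    initialize_app()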

View File

@ -0,0 +1,91 @@
import sys
import os
import subprocess
from typing import List, Dict
import psutil
import platform
def _run_current_with_daemon(name: str, log_file: str):
# Get all arguments except for --daemon
args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
daemon_cmd = [sys.executable] + args
daemon_cmd = " ".join(daemon_cmd)
daemon_cmd += f" > {log_file} 2>&1"
# Check the platform and set the appropriate flags or functions
if platform.system() == "Windows":
process = subprocess.Popen(
daemon_cmd,
creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
shell=True,
)
else: # macOS, Linux, and other Unix-like systems
process = subprocess.Popen(
daemon_cmd,
preexec_fn=os.setsid,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
shell=True,
)
print(f"Started {name} in background with pid: {process.pid}")
def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
try:
import gunicorn
except ImportError as e:
raise ValueError(
"Could not import python package: gunicorn"
"Daemon mode need install gunicorn, please install `pip install gunicorn`"
) from e
from pilot.utils.parameter_utils import EnvArgumentParser
env_to_app = {}
env_to_app.update(os.environ)
app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
env_to_app.update(app_env)
cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
if platform.system() == "Windows":
raise Exception("Not support on windows")
else: # macOS, Linux, and other Unix-like systems
process = subprocess.Popen(cmd, shell=True, env=env_to_app)
print(f"Started {app} with gunicorn in background with pid: {process.pid}")
def _stop_service(
key: str, fullname: str, service_keys: List[str] = None, port: int = None
):
if not service_keys:
service_keys = [sys.argv[0], "start", key]
not_found = True
for process in psutil.process_iter(attrs=["pid", "connections", "cmdline"]):
try:
cmdline = " ".join(process.info["cmdline"])
# Check if all key fragments are in the cmdline
if all(fragment in cmdline for fragment in service_keys):
if port:
for conn in process.info["connections"]:
if (
conn.status == psutil.CONN_LISTEN
and conn.laddr.port == port
):
psutil.Process(process.info["pid"]).terminate()
print(
f"Terminated the {fullname} with PID: {process.info['pid']} listening on port: {port}"
)
not_found = False
else:
psutil.Process(process.info["pid"]).terminate()
print(f"Terminated the {fullname} with PID: {process.info['pid']}")
not_found = False
except (psutil.NoSuchProcess, psutil.AccessDenied):
continue
if not_found:
print(f"{fullname} process not found.")

View File

@ -1,7 +1,8 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING
from typing import Any, List, Optional, Type, Union, Callable
from typing import Any, List, Optional, Type, Union, Callable, Dict
from collections import OrderedDict
@dataclass
@ -60,7 +61,7 @@ class BaseParameters:
f"\n\n=========================== {class_name} ===========================\n"
]
for field_info in fields(self):
value = getattr(self, field_info.name)
value = _get_simple_privacy_field_value(self, field_info)
parameters.append(f"{field_info.name}: {value}")
parameters.append(
"\n======================================================================\n\n"
@ -68,6 +69,55 @@ class BaseParameters:
return "\n".join(parameters)
def _get_simple_privacy_field_value(obj, field_info):
"""Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.
This function reads the metadata of a field to check if it's tagged with 'privacy'.
If the 'privacy' tag is present, then it modifies the value based on its type
for privacy concerns:
- int: returns -999
- float: returns -999.0
- bool: returns False
- str: if length > 5, masks the middle part and returns first and last char;
otherwise, returns "******"
Parameters:
- obj: The dataclass instance.
- field_info: A Field object that contains information about the dataclass field.
Returns:
The original or modified value of the field based on the privacy rules.
Example usage:
@dataclass
class Person:
name: str
age: int
ssn: str = field(metadata={"tags": "privacy"})
p = Person("Alice", 30, "123-45-6789")
print(_get_simple_privacy_field_value(p, Person.ssn)) # A******9
"""
tags = field_info.metadata.get("tags")
tags = [] if not tags else tags.split(",")
is_privacy = False
if tags and "privacy" in tags:
is_privacy = True
value = getattr(obj, field_info.name)
if not is_privacy or not value:
return value
field_type = EnvArgumentParser._get_argparse_type(field_info.type)
if field_type is int:
return -999
if field_type is float:
return -999.0
if field_type is bool:
return False
# str
if len(value) > 5:
return value[0] + "******" + value[-1]
return "******"
def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
"""Get the value from the environment variable, ignoring the case of the key"""
if env_prefix:
@ -115,24 +165,11 @@ class EnvArgumentParser:
if not env_var_value:
env_var_value = kwargs.get(field.name)
print(f"env_var_value: {env_var_value} for {field.name}")
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
if env_var_value:
argument_kwargs["default"] = env_var_value
argument_kwargs["required"] = False
parser.add_argument(f"--{field.name}", **argument_kwargs)
EnvArgumentParser._build_single_argparse_option(
parser, field, env_var_value
)
# Parse the command-line arguments
cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
@ -148,7 +185,7 @@ class EnvArgumentParser:
return dataclass_type(**kwargs)
@staticmethod
def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
for field in fields(dataclass_type):
help_text = field.metadata.get("help", "")
@ -173,11 +210,6 @@ class EnvArgumentParser:
import functools
from collections import OrderedDict
# TODO dynamic configuration
# pre_args = _SimpleArgParser('model_name', 'model_path')
# pre_args.parse()
# print(pre_args)
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
@ -225,6 +257,48 @@ class EnvArgumentParser:
return decorator
@staticmethod
def create_argparse_option(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
) -> argparse.ArgumentParser:
combined_fields = _merge_dataclass_types(
*dataclass_types, _dynamic_factory=_dynamic_factory
)
parser = argparse.ArgumentParser()
for _, field in reversed(combined_fields.items()):
EnvArgumentParser._build_single_argparse_option(parser, field)
return parser
@staticmethod
def _build_single_argparse_option(
parser: argparse.ArgumentParser, field, default_value=None
):
# Add a command-line argument for this field
help_text = field.metadata.get("help", "")
valid_values = field.metadata.get("valid_values", None)
short_name = field.metadata.get("short", None)
argument_kwargs = {
"type": EnvArgumentParser._get_argparse_type(field.type),
"help": help_text,
"choices": valid_values,
"required": EnvArgumentParser._is_require_type(field.type),
}
if field.default != MISSING:
argument_kwargs["default"] = field.default
argument_kwargs["required"] = False
if default_value:
argument_kwargs["default"] = default_value
argument_kwargs["required"] = False
if field.type is bool or field.type == Optional[bool]:
argument_kwargs["action"] = "store_true"
del argument_kwargs["type"]
del argument_kwargs["choices"]
names = []
if short_name:
names.append(f"-{short_name}")
names.append(f"--{field.name}")
parser.add_argument(*names, **argument_kwargs)
@staticmethod
def _get_argparse_type(field_type: Type) -> Type:
# Return the appropriate type for argparse to use based on the field type
@ -255,6 +329,42 @@ class EnvArgumentParser:
def _is_require_type(field_type: Type) -> str:
return field_type not in [Optional[int], Optional[float], Optional[bool]]
@staticmethod
def _kwargs_to_env_key_value(
kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__"
) -> Dict[str, str]:
return {prefix + k: str(v) for k, v in kwargs.items()}
@staticmethod
def _read_env_key_value(
prefix: str = "__dbgpt_gunicorn__env_prefix__",
) -> List[str]:
env_args = []
for key, value in os.environ.items():
if key.startswith(prefix):
arg_key = "--" + key.replace(prefix, "")
if value.lower() in ["true", "1"]:
# Flag args
env_args.append(arg_key)
elif not value.lower() in ["false", "0"]:
env_args.extend([arg_key, value])
return env_args
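For illustration, a sketch of the round trip these two helpers implement for gunicorn mode, assuming the default prefix above; the exact ordering of the emitted args depends on the environment.

import os
from pilot.utils.parameter_utils import EnvArgumentParser

# Encode kwargs as prefixed env vars in the parent, decode them back into CLI args in the child
env = EnvArgumentParser._kwargs_to_env_key_value({"port": 5000, "daemon": True})
os.environ.update(env)
print(EnvArgumentParser._read_env_key_value())
# expected (order may vary): ['--port', '5000', '--daemon']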
def _merge_dataclass_types(
*dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
) -> OrderedDict:
combined_fields = OrderedDict()
if _dynamic_factory:
_types = _dynamic_factory()
if _types:
dataclass_types = list(_types)
for dataclass_type in dataclass_types:
for field in fields(dataclass_type):
if field.name not in combined_fields:
combined_fields[field.name] = field
return combined_fields
def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
descriptions = []
@ -297,7 +407,7 @@ class _SimpleArgParser:
self.params[prev_arg] = None
def _get_param(self, key):
return self.params.get(key.replace("_", "-"), None)
return self.params.get(key.replace("_", "-")) or self.params.get(key)
def __getattr__(self, item):
return self._get_param(item)

View File

@ -161,3 +161,15 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
raise e
logging.warning("Cant not get running event loop, create new event loop now")
return asyncio.get_event_loop_policy().get_event_loop()
def logging_str_to_uvicorn_level(log_level_str):
level_str_mapping = {
"CRITICAL": "critical",
"ERROR": "error",
"WARNING": "warning",
"INFO": "info",
"DEBUG": "debug",
"NOTSET": "info",
}
return level_str_mapping.get(log_level_str.upper(), "info")
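A quick usage sketch of the mapping above, assuming the function lives in pilot.utils.utils as imported earlier in this commit:

from pilot.utils.utils import logging_str_to_uvicorn_level

assert logging_str_to_uvicorn_level("WARNING") == "warning"
assert logging_str_to_uvicorn_level("unknown-level") == "info"  # falls back to "info"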