From b6a4fd8a6213c334a7fc322282d334a01269e5e7 Mon Sep 17 00:00:00 2001 From: FangYin Cheng Date: Tue, 5 Sep 2023 11:26:24 +0800 Subject: [PATCH] feat: Multi-model support with proxyllm and add more command-cli --- docker-compose.yml | 3 +- docker/base/Dockerfile | 1 + .../cluster-docker-compose.yml | 68 ++++++++ pilot/configs/config.py | 4 + pilot/configs/model_config.py | 2 +- pilot/model/adapter.py | 30 +++- pilot/model/cli.py | 82 ++++++--- pilot/model/llm_out/proxy_llm.py | 7 +- pilot/model/loader.py | 11 +- pilot/model/parameter.py | 57 +++++- pilot/model/proxy/llms/bard.py | 20 ++- pilot/model/proxy/llms/chatgpt.py | 29 ++-- pilot/model/proxy/llms/claude.py | 8 +- pilot/model/proxy/llms/gpt4.py | 7 - pilot/model/proxy/llms/proxy_model.py | 9 + pilot/model/proxy/llms/tongyi.py | 8 +- pilot/model/proxy/llms/wenxin.py | 8 +- pilot/model/proxy/llms/zhipu.py | 18 ++ pilot/model/worker/default_worker.py | 3 + pilot/model/worker/manager.py | 2 +- pilot/scripts/cli_scripts.py | 13 +- pilot/server/_cli.py | 32 ++++ pilot/server/base.py | 48 +++++- pilot/server/dbgpt_server.py | 63 ++++--- pilot/utils/command_utils.py | 91 ++++++++++ pilot/utils/parameter_utils.py | 162 +++++++++++++++--- pilot/utils/utils.py | 12 ++ 27 files changed, 668 insertions(+), 130 deletions(-) create mode 100644 docker/compose_examples/cluster-docker-compose.yml delete mode 100644 pilot/model/proxy/llms/gpt4.py create mode 100644 pilot/model/proxy/llms/proxy_model.py create mode 100644 pilot/model/proxy/llms/zhipu.py create mode 100644 pilot/server/_cli.py create mode 100644 pilot/utils/command_utils.py diff --git a/docker-compose.yml b/docker-compose.yml index 84d19525f..24856bff2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -18,7 +18,7 @@ services: networks: - dbgptnet webserver: - image: eosphorosai/dbgpt-allinone:latest + image: eosphorosai/dbgpt:latest command: python3 pilot/server/dbgpt_server.py environment: - LOCAL_DB_HOST=db @@ -46,7 +46,6 @@ services: reservations: devices: - driver: nvidia - device_ids: ['0'] capabilities: [gpu] volumes: dbgpt-myql-db: diff --git a/docker/base/Dockerfile b/docker/base/Dockerfile index e273b8bd9..3abd04909 100644 --- a/docker/base/Dockerfile +++ b/docker/base/Dockerfile @@ -55,6 +55,7 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \ && sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \ fi;) +ENV PYTHONPATH "/app:$PYTHONPATH" EXPOSE 5000 CMD ["python3", "pilot/server/dbgpt_server.py"] \ No newline at end of file diff --git a/docker/compose_examples/cluster-docker-compose.yml b/docker/compose_examples/cluster-docker-compose.yml new file mode 100644 index 000000000..c4928a53a --- /dev/null +++ b/docker/compose_examples/cluster-docker-compose.yml @@ -0,0 +1,68 @@ +version: '3.10' + +services: + controller: + image: eosphorosai/dbgpt:latest + command: dbgpt start controller + networks: + - dbgptnet + worker: + image: eosphorosai/dbgpt:latest + command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000 + environment: + - DBGPT_LOG_LEVEL=DEBUG + depends_on: + - controller + volumes: + - /data:/data + # Please modify it to your own model directory + - /data/models:/app/models + networks: + - dbgptnet + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] + webserver: + image: eosphorosai/dbgpt:latest + command: dbgpt start webserver --light + environment: + - DBGPT_LOG_LEVEL=DEBUG + - 
LOCAL_DB_PATH=data/default_sqlite.db + - LOCAL_DB_TYPE=sqlite + - ALLOWLISTED_PLUGINS=db_dashboard + - LLM_MODEL=vicuna-13b-v1.5 + - MODEL_SERVER=http://controller:8000 + depends_on: + - controller + - worker + volumes: + - /data:/data + # Please modify it to your own model directory + - /data/models:/app/models + - dbgpt-data:/app/pilot/data + - dbgpt-message:/app/pilot/message + # env_file: + # - .env.template + ports: + - 5000:5000/tcp + # webserver may be failed, it must wait all sqls in /docker-entrypoint-initdb.d execute finish. + restart: unless-stopped + networks: + - dbgptnet + deploy: + resources: + reservations: + devices: + - driver: nvidia + capabilities: [gpu] +volumes: + dbgpt-myql-db: + dbgpt-data: + dbgpt-message: +networks: + dbgptnet: + driver: bridge + name: dbgptnet \ No newline at end of file diff --git a/pilot/configs/config.py b/pilot/configs/config.py index 86af93561..8bfaed757 100644 --- a/pilot/configs/config.py +++ b/pilot/configs/config.py @@ -41,6 +41,10 @@ class Config(metaclass=Singleton): # This is a proxy server, just for test_py. we will remove this later. self.proxy_api_key = os.getenv("PROXY_API_KEY") self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY") + # In order to be compatible with the new and old model parameter design + if self.bard_proxy_api_key: + os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key + self.proxy_server_url = os.getenv("PROXY_SERVER_URL") self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY") diff --git a/pilot/configs/model_config.py b/pilot/configs/model_config.py index 517705edc..0b2ed520a 100644 --- a/pilot/configs/model_config.py +++ b/pilot/configs/model_config.py @@ -70,7 +70,7 @@ LLM_MODEL_CONFIG = { "claude_proxyllm": "claude_proxyllm", "wenxin_proxyllm": "wenxin_proxyllm", "tongyi_proxyllm": "tongyi_proxyllm", - "gpt4_proxyllm": "gpt4_proxyllm", + "zhipu_proxyllm": "zhipu_proxyllm", "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"), "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"), "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"), diff --git a/pilot/model/adapter.py b/pilot/model/adapter.py index cf20742dd..44d115820 100644 --- a/pilot/model/adapter.py +++ b/pilot/model/adapter.py @@ -4,7 +4,7 @@ import os import re from pathlib import Path -from typing import List, Tuple +from typing import List, Tuple, Callable, Type from functools import cache from transformers import ( AutoModel, @@ -12,7 +12,11 @@ from transformers import ( AutoTokenizer, LlamaTokenizer, ) -from pilot.model.parameter import ModelParameters, LlamaCppModelParameters +from pilot.model.parameter import ( + ModelParameters, + LlamaCppModelParameters, + ProxyModelParameters, +) from pilot.configs.model_config import get_device from pilot.configs.config import Config from pilot.logs import logger @@ -26,6 +30,7 @@ class ModelType: HF = "huggingface" LLAMA_CPP = "llama.cpp" + PROXY = "proxy" # TODO, support more model type @@ -43,6 +48,8 @@ class BaseLLMAdaper: model_type = model_type if model_type else self.model_type() if model_type == ModelType.LLAMA_CPP: return LlamaCppModelParameters + elif model_type == ModelType.PROXY: + return ProxyModelParameters return ModelParameters def match(self, model_path: str): @@ -76,7 +83,7 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: return adapter for adapter in llm_model_adapters: - if adapter.match(model_path): + if model_path and adapter.match(model_path): logger.info( f"Found llm model adapter with model path: 
{model_path}, {adapter}" ) @@ -87,6 +94,20 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper: ) +def _dynamic_model_parser() -> Callable[[None], List[Type]]: + from pilot.utils.parameter_utils import _SimpleArgParser + + pre_args = _SimpleArgParser("model_name", "model_path") + pre_args.parse() + model_name = pre_args.get("model_name") + model_path = pre_args.get("model_path") + if model_name is None: + return None + llm_adapter = get_llm_model_adapter(model_name, model_path) + param_class = llm_adapter.model_param_class() + return [param_class] + + # TODO support cpu? for practise we support gpt4all or chatglm-6b-int4? @@ -281,6 +302,9 @@ class GPT4AllAdapter(BaseLLMAdaper): class ProxyllmAdapter(BaseLLMAdaper): """The model adapter for local proxy""" + def model_type(self) -> str: + return ModelType.PROXY + def match(self, model_path: str): return "proxyllm" in model_path diff --git a/pilot/model/cli.py b/pilot/model/cli.py index a56cae6dd..5f8a23def 100644 --- a/pilot/model/cli.py +++ b/pilot/model/cli.py @@ -1,8 +1,11 @@ import click import functools import logging +import os +from typing import Callable, List, Type from pilot.model.controller.registry import ModelRegistryClient +from pilot.configs.model_config import LOGDIR from pilot.model.base import WorkerApplyType from pilot.model.parameter import ( ModelControllerParameters, @@ -11,6 +14,8 @@ from pilot.model.parameter import ( ) from pilot.utils import get_or_create_event_loop from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.utils.command_utils import _run_current_with_daemon, _stop_service + MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000" @@ -157,46 +162,81 @@ def worker_apply( print(res) +def add_stop_server_options(func): + @click.option( + "--port", + type=int, + default=None, + required=False, + help=("The port to stop"), + ) + @functools.wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + @click.command(name="controller") @EnvArgumentParser.create_click_option(ModelControllerParameters) def start_model_controller(**kwargs): """Start model controller""" + from pilot.model.controller.controller import run_model_controller - run_model_controller() + if kwargs["daemon"]: + log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log") + _run_current_with_daemon("ModelController", log_file) + else: + from pilot.model.controller.controller import run_model_controller + + run_model_controller() @click.command(name="controller") -def stop_model_controller(**kwargs): +@add_stop_server_options +def stop_model_controller(port: int): """Start model controller""" - raise NotImplementedError + # Command fragments to check against running processes + _stop_service("controller", "ModelController", port=port) + + +def _model_dynamic_factory() -> Callable[[None], List[Type]]: + from pilot.model.adapter import _dynamic_model_parser + + param_class = _dynamic_model_parser() + fix_class = [ModelWorkerParameters] + if not param_class: + param_class = [ModelParameters] + fix_class += param_class + return fix_class @click.command(name="worker") -@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters) +@EnvArgumentParser.create_click_option( + ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory +) def start_model_worker(**kwargs): """Start model worker""" - from pilot.model.worker.manager import run_worker_manager + if kwargs["daemon"]: + port = kwargs["port"] + model_type = 
kwargs.get("worker_type") or "llm" + log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log") + _run_current_with_daemon("ModelWorker", log_file) + else: + from pilot.model.worker.manager import run_worker_manager - run_worker_manager() + run_worker_manager() @click.command(name="worker") -def stop_model_worker(**kwargs): +@add_stop_server_options +def stop_model_worker(port: int): """Stop model worker""" - raise NotImplementedError - - -@click.command(name="webserver") -def start_webserver(**kwargs): - """Start webserver(dbgpt_server.py)""" - raise NotImplementedError - - -@click.command(name="webserver") -def stop_webserver(**kwargs): - """Stop webserver(dbgpt_server.py)""" - raise NotImplementedError + name = "ModelWorker" + if port: + name = f"{name}-{port}" + _stop_service("worker", name, port=port) @click.command(name="apiserver") @@ -205,7 +245,7 @@ def start_apiserver(**kwargs): raise NotImplementedError -@click.command(name="controller") +@click.command(name="apiserver") def stop_apiserver(**kwargs): """Start apiserver(TODO)""" raise NotImplementedError diff --git a/pilot/model/llm_out/proxy_llm.py b/pilot/model/llm_out/proxy_llm.py index 3af15d402..c8febba9b 100644 --- a/pilot/model/llm_out/proxy_llm.py +++ b/pilot/model/llm_out/proxy_llm.py @@ -9,7 +9,9 @@ from pilot.model.proxy.llms.bard import bard_generate_stream from pilot.model.proxy.llms.claude import claude_generate_stream from pilot.model.proxy.llms.wenxin import wenxin_generate_stream from pilot.model.proxy.llms.tongyi import tongyi_generate_stream -from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream +from pilot.model.proxy.llms.zhipu import zhipu_generate_stream + +# from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream CFG = Config() @@ -20,9 +22,10 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048) "chatgpt_proxyllm": chatgpt_generate_stream, "bard_proxyllm": bard_generate_stream, "claude_proxyllm": claude_generate_stream, - "gpt4_proxyllm": gpt4_generate_stream, + # "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream "wenxin_proxyllm": wenxin_generate_stream, "tongyi_proxyllm": tongyi_generate_stream, + "zhipu_proxyllm": zhipu_generate_stream, } default_error_message = f"{CFG.LLM_MODEL} LLM is not supported" diff --git a/pilot/model/loader.py b/pilot/model/loader.py index f30a9811d..025e93091 100644 --- a/pilot/model/loader.py +++ b/pilot/model/loader.py @@ -8,6 +8,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType from pilot.model.parameter import ( ModelParameters, LlamaCppModelParameters, + ProxyModelParameters, ) from pilot.utils import get_gpu_memory from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case @@ -114,11 +115,12 @@ class ModelLoader: llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) model_type = llm_adapter.model_type() self.prompt_template = model_params.prompt_template - logger.info(f"model_params:\n{model_params}") if model_type == ModelType.HF: return huggingface_loader(llm_adapter, model_params) elif model_type == ModelType.LLAMA_CPP: return llamacpp_loader(llm_adapter, model_params) + elif model_type == ModelType.PROXY: + return proxyllm_loader(llm_adapter, model_params) else: raise Exception(f"Unkown model type {model_type}") @@ -346,3 +348,10 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam model_path = model_params.model_path model, tokenizer = 
LlamaCppModel.from_pretrained(model_path, model_params) return model, tokenizer + + +def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters): + from pilot.model.proxy.llms.proxy_model import ProxyModel + + model = ProxyModel(model_params) + return model, model diff --git a/pilot/model/parameter.py b/pilot/model/parameter.py index bddea6663..c140f7e35 100644 --- a/pilot/model/parameter.py +++ b/pilot/model/parameter.py @@ -27,6 +27,9 @@ class ModelControllerParameters(BaseParameters): port: Optional[int] = field( default=8000, metadata={"help": "Model Controller deploy port"} ) + daemon: Optional[bool] = field( + default=False, metadata={"help": "Run Model Controller in background"} + ) @dataclass @@ -50,6 +53,9 @@ class ModelWorkerParameters(BaseParameters): port: Optional[int] = field( default=8000, metadata={"help": "Model worker deploy port"} ) + daemon: Optional[bool] = field( + default=False, metadata={"help": "Run Model Worker in background"} + ) limit_model_concurrency: Optional[int] = field( default=5, metadata={"help": "Model concurrency limit"} ) @@ -109,9 +115,13 @@ class EmbeddingModelParameters(BaseParameters): @dataclass -class ModelParameters(BaseParameters): +class BaseModelParameters(BaseParameters): model_name: str = field(metadata={"help": "Model name", "tags": "fixed"}) model_path: str = field(metadata={"help": "Model path", "tags": "fixed"}) + + +@dataclass +class ModelParameters(BaseModelParameters): device: Optional[str] = field( default=None, metadata={ @@ -120,7 +130,10 @@ class ModelParameters(BaseParameters): ) model_type: Optional[str] = field( default="huggingface", - metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"}, + metadata={ + "help": "Model type, huggingface, llama.cpp and proxy", + "tags": "fixed", + }, ) prompt_template: Optional[str] = field( default=None, @@ -221,3 +234,43 @@ class LlamaCppModelParameters(ModelParameters): "help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=False is configured." }, ) + + +@dataclass +class ProxyModelParameters(BaseModelParameters): + proxy_server_url: str = field( + metadata={ + "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions" + }, + ) + proxy_api_key: str = field( + metadata={"tags": "privacy", "help": "The api key of current proxy LLM"}, + ) + proxyllm_backend: Optional[str] = field( + default=None, + metadata={ + "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on" + }, + ) + model_type: Optional[str] = field( + default="proxy", + metadata={ + "help": "Model type, huggingface, llama.cpp and proxy", + "tags": "fixed", + }, + ) + device: Optional[str] = field( + default=None, + metadata={ + "help": "Device to run model. If None, the device is automatically determined" + }, + ) + prompt_template: Optional[str] = field( + default=None, + metadata={ + "help": f"Prompt template. 
If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}" + }, + ) + max_context_size: Optional[int] = field( + default=4096, metadata={"help": "Maximum context size"} + ) diff --git a/pilot/model/proxy/llms/bard.py b/pilot/model/proxy/llms/bard.py index 73f959512..7590547f9 100644 --- a/pilot/model/proxy/llms/bard.py +++ b/pilot/model/proxy/llms/bard.py @@ -1,13 +1,19 @@ import bardapi import requests from typing import List -from pilot.configs.config import Config from pilot.scene.base_message import ModelMessage, ModelMessageRoleType - -CFG = Config() +from pilot.model.proxy.llms.proxy_model import ProxyModel -def bard_generate_stream(model, tokenizer, params, device, context_len=2048): +def bard_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): + model_params = model.get_params() + print(f"Model: {model}, model_params: {model_params}") + + proxy_api_key = model_params.proxy_api_key + proxy_server_url = model_params.proxy_server_url + history = [] messages: List[ModelMessage] = params["messages"] for message in messages: @@ -35,18 +41,18 @@ def bard_generate_stream(model, tokenizer, params, device, context_len=2048): if msg.get("content"): msgs.append(msg["content"]) - if CFG.proxy_server_url is not None: + if proxy_server_url is not None: headers = {"Content-Type": "application/json"} payloads = {"input": "\n".join(msgs)} response = requests.post( - CFG.proxy_server_url, headers=headers, json=payloads, stream=False + proxy_server_url, headers=headers, json=payloads, stream=False ) if response.ok: yield response.text else: yield f"bard proxy url request failed!, response = {str(response)}" else: - response = bardapi.core.Bard(CFG.bard_proxy_api_key).get_answer("\n".join(msgs)) + response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs)) if response is not None and response.get("content") is not None: yield str(response["content"]) diff --git a/pilot/model/proxy/llms/chatgpt.py b/pilot/model/proxy/llms/chatgpt.py index cf61e78ac..70dea67f0 100644 --- a/pilot/model/proxy/llms/chatgpt.py +++ b/pilot/model/proxy/llms/chatgpt.py @@ -4,18 +4,27 @@ import json import requests from typing import List -from pilot.configs.config import Config from pilot.scene.base_message import ModelMessage, ModelMessageRoleType - -CFG = Config() +from pilot.model.proxy.llms.proxy_model import ProxyModel -def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048): +def chatgpt_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): history = [] + model_params = model.get_params() + print(f"Model: {model}, model_params: {model_params}") + + proxy_api_key = model_params.proxy_api_key + proxy_server_url = model_params.proxy_server_url + proxyllm_backend = model_params.proxyllm_backend + if not proxyllm_backend: + proxyllm_backend = "gpt-3.5-turbo" + headers = { - "Authorization": "Bearer " + CFG.proxy_api_key, - "Token": CFG.proxy_api_key, + "Authorization": "Bearer " + proxy_api_key, + "Token": proxy_api_key, } messages: List[ModelMessage] = params["messages"] @@ -42,16 +51,16 @@ def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048): history.append(last_user_input) payloads = { - "model": "gpt-3.5-turbo", # just for test, remove this later + "model": proxyllm_backend, # just for test, remove this later "messages": history, "temperature": params.get("temperature"), "max_tokens": params.get("max_new_tokens"), "stream": True, } - 
res = requests.post( - CFG.proxy_server_url, headers=headers, json=payloads, stream=True - ) + res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True) + + print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}") text = "" for line in res.iter_lines(): diff --git a/pilot/model/proxy/llms/claude.py b/pilot/model/proxy/llms/claude.py index 236df9e54..c14e6a7e4 100644 --- a/pilot/model/proxy/llms/claude.py +++ b/pilot/model/proxy/llms/claude.py @@ -1,7 +1,7 @@ -from pilot.configs.config import Config - -CFG = Config() +from pilot.model.proxy.llms.proxy_model import ProxyModel -def claude_generate_stream(model, tokenizer, params, device, context_len=2048): +def claude_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): yield "claude LLM was not supported!" diff --git a/pilot/model/proxy/llms/gpt4.py b/pilot/model/proxy/llms/gpt4.py deleted file mode 100644 index 075351266..000000000 --- a/pilot/model/proxy/llms/gpt4.py +++ /dev/null @@ -1,7 +0,0 @@ -from pilot.configs.config import Config - -CFG = Config() - - -def gpt4_generate_stream(model, tokenizer, params, device, context_len=2048): - yield "gpt4 LLM was not supported!" diff --git a/pilot/model/proxy/llms/proxy_model.py b/pilot/model/proxy/llms/proxy_model.py new file mode 100644 index 000000000..587bb1360 --- /dev/null +++ b/pilot/model/proxy/llms/proxy_model.py @@ -0,0 +1,9 @@ +from pilot.model.parameter import ProxyModelParameters + + +class ProxyModel: + def __init__(self, model_params: ProxyModelParameters) -> None: + self._model_params = model_params + + def get_params(self) -> ProxyModelParameters: + return self._model_params diff --git a/pilot/model/proxy/llms/tongyi.py b/pilot/model/proxy/llms/tongyi.py index 9b056f1ef..0c530cdcf 100644 --- a/pilot/model/proxy/llms/tongyi.py +++ b/pilot/model/proxy/llms/tongyi.py @@ -1,7 +1,7 @@ -from pilot.configs.config import Config - -CFG = Config() +from pilot.model.proxy.llms.proxy_model import ProxyModel -def tongyi_generate_stream(model, tokenizer, params, device, context_len=2048): +def tongyi_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): yield "tongyi LLM was not supported!" diff --git a/pilot/model/proxy/llms/wenxin.py b/pilot/model/proxy/llms/wenxin.py index a3e874928..6c621476e 100644 --- a/pilot/model/proxy/llms/wenxin.py +++ b/pilot/model/proxy/llms/wenxin.py @@ -1,7 +1,7 @@ -from pilot.configs.config import Config - -CFG = Config() +from pilot.model.proxy.llms.proxy_model import ProxyModel -def wenxin_generate_stream(model, tokenizer, params, device, context_len=2048): +def wenxin_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): yield "wenxin LLM is not supported!" 
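The proxy providers above now share one calling convention: the worker hands each *_generate_stream function a ProxyModel, and the function reads its endpoint, credentials and backend name from model.get_params() instead of the global Config. A minimal sketch of that contract follows; every value in it is a placeholder assumption, not taken from this patch.

from pilot.model.parameter import ProxyModelParameters
from pilot.model.proxy.llms.proxy_model import ProxyModel

# Placeholder configuration; in the server these values come from environment
# variables / CLI options parsed by EnvArgumentParser.
proxy_params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    model_path="chatgpt_proxyllm",
    proxy_server_url="https://api.openai.com/v1/chat/completions",
    proxy_api_key="YOUR_API_KEY",
    proxyllm_backend="gpt-3.5-turbo",
)
model = ProxyModel(proxy_params)

def example_generate_stream(model: ProxyModel, tokenizer, params, device, context_len=2048):
    # The same unpacking that bard/chatgpt (and zhipu below) perform before calling out.
    model_params = model.get_params()
    yield f"would call {model_params.proxy_server_url} as {model_params.proxyllm_backend}"

print(next(example_generate_stream(model, None, {}, "cpu")))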
diff --git a/pilot/model/proxy/llms/zhipu.py b/pilot/model/proxy/llms/zhipu.py new file mode 100644 index 000000000..50d2b4080 --- /dev/null +++ b/pilot/model/proxy/llms/zhipu.py @@ -0,0 +1,18 @@ +from pilot.model.proxy.llms.proxy_model import ProxyModel + + +def zhipu_generate_stream( + model: ProxyModel, tokenizer, params, device, context_len=2048 +): + """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview""" + model_params = model.get_params() + print(f"Model: {model}, model_params: {model_params}") + + proxy_api_key = model_params.proxy_api_key + proxy_server_url = model_params.proxy_server_url + proxyllm_backend = model_params.proxyllm_backend + + if not proxyllm_backend: + proxyllm_backend = "chatglm_pro" + # TODO + yield "Zhipu LLM was not supported!" diff --git a/pilot/model/worker/default_worker.py b/pilot/model/worker/default_worker.py index fad8d0b5e..331be53ec 100644 --- a/pilot/model/worker/default_worker.py +++ b/pilot/model/worker/default_worker.py @@ -33,6 +33,9 @@ class DefaultModelWorker(ModelWorker): self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path) model_type = self.llm_adapter.model_type() self.param_cls = self.llm_adapter.model_param_class(model_type) + logger.info( + f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}" + ) self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path) self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func( diff --git a/pilot/model/worker/manager.py b/pilot/model/worker/manager.py index 79c455f54..cb76f2529 100644 --- a/pilot/model/worker/manager.py +++ b/pilot/model/worker/manager.py @@ -309,7 +309,7 @@ class LocalWorkerManager(WorkerManager): else: # Apply to all workers worker_instances = list(itertools.chain(*self.workers.values())) - logger.info(f"Apply to all workers: {worker_instances}") + logger.info(f"Apply to all workers") return await asyncio.gather( *(apply_func(worker) for worker in worker_instances) ) diff --git a/pilot/scripts/cli_scripts.py b/pilot/scripts/cli_scripts.py index b785763af..6a0bcbb5e 100644 --- a/pilot/scripts/cli_scripts.py +++ b/pilot/scripts/cli_scripts.py @@ -61,8 +61,6 @@ try: stop_model_controller, start_model_worker, stop_model_worker, - start_webserver, - stop_webserver, start_apiserver, stop_apiserver, ) @@ -70,17 +68,24 @@ try: add_command_alias(model_cli_group, name="model", parent_group=cli) add_command_alias(start_model_controller, name="controller", parent_group=start) add_command_alias(start_model_worker, name="worker", parent_group=start) - add_command_alias(start_webserver, name="webserver", parent_group=start) add_command_alias(start_apiserver, name="apiserver", parent_group=start) add_command_alias(stop_model_controller, name="controller", parent_group=stop) add_command_alias(stop_model_worker, name="worker", parent_group=stop) - add_command_alias(stop_webserver, name="webserver", parent_group=stop) add_command_alias(stop_apiserver, name="apiserver", parent_group=stop) except ImportError as e: logging.warning(f"Integrating dbgpt model command line tool failed: {e}") +try: + from pilot.server._cli import start_webserver, stop_webserver + + add_command_alias(start_webserver, name="webserver", parent_group=start) + add_command_alias(stop_webserver, name="webserver", parent_group=stop) + +except ImportError as e: + logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}") + try: from pilot.server.knowledge._cli.knowledge_cli import knowledge_cli_group diff 
--git a/pilot/server/_cli.py b/pilot/server/_cli.py new file mode 100644 index 000000000..90993d624 --- /dev/null +++ b/pilot/server/_cli.py @@ -0,0 +1,32 @@ +import click +import os +from pilot.server.base import WebWerverParameters +from pilot.configs.model_config import LOGDIR +from pilot.utils.parameter_utils import EnvArgumentParser +from pilot.utils.command_utils import _run_current_with_daemon, _stop_service + + +@click.command(name="webserver") +@EnvArgumentParser.create_click_option(WebWerverParameters) +def start_webserver(**kwargs): + """Start webserver(dbgpt_server.py)""" + if kwargs["daemon"]: + log_file = os.path.join(LOGDIR, "webserver_uvicorn.log") + _run_current_with_daemon("WebServer", log_file) + else: + from pilot.server.dbgpt_server import run_webserver + + run_webserver(WebWerverParameters(**kwargs)) + + +@click.command(name="webserver") +@click.option( + "--port", + type=int, + default=None, + required=False, + help=("The port to stop"), +) +def stop_webserver(port: int): + """Stop webserver(dbgpt_server.py)""" + _stop_service("webserver", "WebServer", port=port) diff --git a/pilot/server/base.py b/pilot/server/base.py index 1b7ce6ecd..6b44e8472 100644 --- a/pilot/server/base.py +++ b/pilot/server/base.py @@ -2,13 +2,12 @@ import signal import os import threading import sys +from typing import Optional +from dataclasses import dataclass, field -from pilot.summary.db_summary_client import DBSummaryClient -from pilot.commands.command_mange import CommandRegistry from pilot.configs.config import Config +from pilot.utils.parameter_utils import BaseParameters -from pilot.common.plugins import scan_plugins -from pilot.connections.manages.connection_manager import ConnectManager ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -20,12 +19,18 @@ def signal_handler(sig, frame): def async_db_summery(): + from pilot.summary.db_summary_client import DBSummaryClient + client = DBSummaryClient() thread = threading.Thread(target=client.init_db_summary) thread.start() def server_init(args): + from pilot.commands.command_mange import CommandRegistry + from pilot.connections.manages.connection_manager import ConnectManager + from pilot.common.plugins import scan_plugins + # logger.info(f"args: {args}") # init config @@ -63,3 +68,38 @@ def server_init(args): for command in command_disply_commands: command_disply_registry.import_commands(command) cfg.command_disply = command_disply_registry + + +@dataclass +class WebWerverParameters(BaseParameters): + host: Optional[str] = field( + default="0.0.0.0", metadata={"help": "Webserver deploy host"} + ) + port: Optional[int] = field( + default=5000, metadata={"help": "Webserver deploy port"} + ) + daemon: Optional[bool] = field( + default=False, metadata={"help": "Run Webserver in background"} + ) + share: Optional[bool] = field( + default=False, + metadata={ + "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. 
" + }, + ) + log_level: Optional[str] = field( + default="INFO", + metadata={ + "help": "Logging level", + "valid_values": [ + "FATAL", + "ERROR", + "WARNING", + "WARNING", + "INFO", + "DEBUG", + "NOTSET", + ], + }, + ) + light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"}) diff --git a/pilot/server/dbgpt_server.py b/pilot/server/dbgpt_server.py index 8a172120f..440dd20e3 100644 --- a/pilot/server/dbgpt_server.py +++ b/pilot/server/dbgpt_server.py @@ -1,7 +1,7 @@ import os import argparse import sys -import logging +from typing import List ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(ROOT_PATH) @@ -9,7 +9,7 @@ import signal from pilot.configs.config import Config from pilot.configs.model_config import LLM_MODEL_CONFIG -from pilot.server.base import server_init +from pilot.server.base import server_init, WebWerverParameters from fastapi.staticfiles import StaticFiles from fastapi import FastAPI, applications @@ -24,7 +24,7 @@ from pilot.openapi.base import validation_exception_handler from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1 from pilot.commands.disply_type.show_chart_gen import static_message_img_path from pilot.model.worker.manager import initialize_worker_manager_in_client -from pilot.utils.utils import setup_logging +from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level static_file_path = os.path.join(os.getcwd(), "server/static") @@ -84,33 +84,24 @@ def mount_static_files(app): app.add_exception_handler(RequestValidationError, validation_exception_handler) -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--model_list_mode", type=str, default="once", choices=["once", "reload"] - ) - # old version server config - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=5000) - parser.add_argument("--concurrency-count", type=int, default=10) - parser.add_argument("--share", default=False, action="store_true") - parser.add_argument("--log-level", type=str, default=None) - parser.add_argument( - "-light", - "--light", - default=False, - action="store_true", - help="enable light mode", - ) +def initialize_app(param: WebWerverParameters = None, args: List[str] = None): + """Initialize app + If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook. 
+ """ + if not param: + from pilot.utils.parameter_utils import EnvArgumentParser - # init server config - args = parser.parse_args() - setup_logging(logging_level=args.log_level) - server_init(args) + parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option( + WebWerverParameters + ) + param = WebWerverParameters(**vars(parser.parse_args(args=args))) + + setup_logging(logging_level=param.log_level) + server_init(param) model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL] - if not args.light: + if not param.light: print("Model Unified Deployment Mode!") initialize_worker_manager_in_client( app=app, model_name=CFG.LLM_MODEL, model_path=model_path @@ -129,7 +120,25 @@ if __name__ == "__main__": CFG.SERVER_LIGHT_MODE = True mount_static_files(app) + return param + + +def run_uvicorn(param: WebWerverParameters): import uvicorn - uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info") + uvicorn.run( + app, + host=param.host, + port=param.port, + log_level=logging_str_to_uvicorn_level(param.log_level), + ) signal.signal(signal.SIGINT, signal_handler()) + + +def run_webserver(param: WebWerverParameters = None): + param = initialize_app(param) + run_uvicorn(param) + + +if __name__ == "__main__": + run_webserver() diff --git a/pilot/utils/command_utils.py b/pilot/utils/command_utils.py new file mode 100644 index 000000000..f29aff3e1 --- /dev/null +++ b/pilot/utils/command_utils.py @@ -0,0 +1,91 @@ +import sys +import os +import subprocess +from typing import List, Dict +import psutil +import platform + + +def _run_current_with_daemon(name: str, log_file: str): + # Get all arguments except for --daemon + args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"] + daemon_cmd = [sys.executable] + args + daemon_cmd = " ".join(daemon_cmd) + daemon_cmd += f" > {log_file} 2>&1" + + # Check the platform and set the appropriate flags or functions + if platform.system() == "Windows": + process = subprocess.Popen( + daemon_cmd, + creationflags=subprocess.CREATE_NEW_PROCESS_GROUP, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=True, + ) + else: # macOS, Linux, and other Unix-like systems + process = subprocess.Popen( + daemon_cmd, + preexec_fn=os.setsid, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + shell=True, + ) + + print(f"Started {name} in background with pid: {process.pid}") + + +def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict): + try: + import gunicorn + except ImportError as e: + raise ValueError( + "Could not import python package: gunicorn" + "Daemon mode need install gunicorn, please install `pip install gunicorn`" + ) from e + + from pilot.utils.parameter_utils import EnvArgumentParser + + env_to_app = {} + env_to_app.update(os.environ) + app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs) + env_to_app.update(app_env) + cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000" + if platform.system() == "Windows": + raise Exception("Not support on windows") + else: # macOS, Linux, and other Unix-like systems + process = subprocess.Popen(cmd, shell=True, env=env_to_app) + print(f"Started {app} with gunicorn in background with pid: {process.pid}") + + +def _stop_service( + key: str, fullname: str, service_keys: List[str] = None, port: int = None +): + if not service_keys: + service_keys = [sys.argv[0], "start", key] + not_found = True + for process in psutil.process_iter(attrs=["pid", "connections", "cmdline"]): + try: + cmdline = " ".join(process.info["cmdline"]) + + # Check if all key fragments are in the 
cmdline + if all(fragment in cmdline for fragment in service_keys): + if port: + for conn in process.info["connections"]: + if ( + conn.status == psutil.CONN_LISTEN + and conn.laddr.port == port + ): + psutil.Process(process.info["pid"]).terminate() + print( + f"Terminated the {fullname} with PID: {process.info['pid']} listening on port: {port}" + ) + not_found = False + else: + psutil.Process(process.info["pid"]).terminate() + print(f"Terminated the {fullname} with PID: {process.info['pid']}") + not_found = False + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + if not_found: + print(f"{fullname} process not found.") diff --git a/pilot/utils/parameter_utils.py b/pilot/utils/parameter_utils.py index 026e1bec1..a584a4e69 100644 --- a/pilot/utils/parameter_utils.py +++ b/pilot/utils/parameter_utils.py @@ -1,7 +1,8 @@ import argparse import os from dataclasses import dataclass, fields, MISSING -from typing import Any, List, Optional, Type, Union, Callable +from typing import Any, List, Optional, Type, Union, Callable, Dict +from collections import OrderedDict @dataclass @@ -60,7 +61,7 @@ class BaseParameters: f"\n\n=========================== {class_name} ===========================\n" ] for field_info in fields(self): - value = getattr(self, field_info.name) + value = _get_simple_privacy_field_value(self, field_info) parameters.append(f"{field_info.name}: {value}") parameters.append( "\n======================================================================\n\n" @@ -68,6 +69,55 @@ class BaseParameters: return "\n".join(parameters) +def _get_simple_privacy_field_value(obj, field_info): + """Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary. + + This function reads the metadata of a field to check if it's tagged with 'privacy'. + If the 'privacy' tag is present, then it modifies the value based on its type + for privacy concerns: + - int: returns -999 + - float: returns -999.0 + - bool: returns False + - str: if length > 5, masks the middle part and returns first and last char; + otherwise, returns "******" + + Parameters: + - obj: The dataclass instance. + - field_info: A Field object that contains information about the dataclass field. + + Returns: + The original or modified value of the field based on the privacy rules. 
+ + Example usage: + @dataclass + class Person: + name: str + age: int + ssn: str = field(metadata={"tags": "privacy"}) + p = Person("Alice", 30, "123-45-6789") + print(_get_simple_privacy_field_value(p, Person.ssn)) # A******9 + """ + tags = field_info.metadata.get("tags") + tags = [] if not tags else tags.split(",") + is_privacy = False + if tags and "privacy" in tags: + is_privacy = True + value = getattr(obj, field_info.name) + if not is_privacy or not value: + return value + field_type = EnvArgumentParser._get_argparse_type(field_info.type) + if field_type is int: + return -999 + if field_type is float: + return -999.0 + if field_type is bool: + return False + # str + if len(value) > 5: + return value[0] + "******" + value[-1] + return "******" + + def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None): """Get the value from the environment variable, ignoring the case of the key""" if env_prefix: @@ -115,24 +165,11 @@ class EnvArgumentParser: if not env_var_value: env_var_value = kwargs.get(field.name) + print(f"env_var_value: {env_var_value} for {field.name}") # Add a command-line argument for this field - help_text = field.metadata.get("help", "") - valid_values = field.metadata.get("valid_values", None) - - argument_kwargs = { - "type": EnvArgumentParser._get_argparse_type(field.type), - "help": help_text, - "choices": valid_values, - "required": EnvArgumentParser._is_require_type(field.type), - } - if field.default != MISSING: - argument_kwargs["default"] = field.default - argument_kwargs["required"] = False - if env_var_value: - argument_kwargs["default"] = env_var_value - argument_kwargs["required"] = False - - parser.add_argument(f"--{field.name}", **argument_kwargs) + EnvArgumentParser._build_single_argparse_option( + parser, field, env_var_value + ) # Parse the command-line arguments cmd_args, cmd_argv = parser.parse_known_args(args=command_args) @@ -148,7 +185,7 @@ class EnvArgumentParser: return dataclass_type(**kwargs) @staticmethod - def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser: + def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description=dataclass_type.__doc__) for field in fields(dataclass_type): help_text = field.metadata.get("help", "") @@ -173,11 +210,6 @@ class EnvArgumentParser: import functools from collections import OrderedDict - # TODO dynamic configuration - # pre_args = _SimpleArgParser('model_name', 'model_path') - # pre_args.parse() - # print(pre_args) - combined_fields = OrderedDict() if _dynamic_factory: _types = _dynamic_factory() @@ -225,6 +257,48 @@ class EnvArgumentParser: return decorator + @staticmethod + def create_argparse_option( + *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None + ) -> argparse.ArgumentParser: + combined_fields = _merge_dataclass_types( + *dataclass_types, _dynamic_factory=_dynamic_factory + ) + parser = argparse.ArgumentParser() + for _, field in reversed(combined_fields.items()): + EnvArgumentParser._build_single_argparse_option(parser, field) + return parser + + @staticmethod + def _build_single_argparse_option( + parser: argparse.ArgumentParser, field, default_value=None + ): + # Add a command-line argument for this field + help_text = field.metadata.get("help", "") + valid_values = field.metadata.get("valid_values", None) + short_name = field.metadata.get("short", None) + argument_kwargs = { + "type": EnvArgumentParser._get_argparse_type(field.type), + "help": help_text, + 
"choices": valid_values, + "required": EnvArgumentParser._is_require_type(field.type), + } + if field.default != MISSING: + argument_kwargs["default"] = field.default + argument_kwargs["required"] = False + if default_value: + argument_kwargs["default"] = default_value + argument_kwargs["required"] = False + if field.type is bool or field.type == Optional[bool]: + argument_kwargs["action"] = "store_true" + del argument_kwargs["type"] + del argument_kwargs["choices"] + names = [] + if short_name: + names.append(f"-{short_name}") + names.append(f"--{field.name}") + parser.add_argument(*names, **argument_kwargs) + @staticmethod def _get_argparse_type(field_type: Type) -> Type: # Return the appropriate type for argparse to use based on the field type @@ -255,6 +329,42 @@ class EnvArgumentParser: def _is_require_type(field_type: Type) -> str: return field_type not in [Optional[int], Optional[float], Optional[bool]] + @staticmethod + def _kwargs_to_env_key_value( + kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__" + ) -> Dict[str, str]: + return {prefix + k: str(v) for k, v in kwargs.items()} + + @staticmethod + def _read_env_key_value( + prefix: str = "__dbgpt_gunicorn__env_prefix__", + ) -> List[str]: + env_args = [] + for key, value in os.environ.items(): + if key.startswith(prefix): + arg_key = "--" + key.replace(prefix, "") + if value.lower() in ["true", "1"]: + # Flag args + env_args.append(arg_key) + elif not value.lower() in ["false", "0"]: + env_args.extend([arg_key, value]) + return env_args + + +def _merge_dataclass_types( + *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None +) -> OrderedDict: + combined_fields = OrderedDict() + if _dynamic_factory: + _types = _dynamic_factory() + if _types: + dataclass_types = list(_types) + for dataclass_type in dataclass_types: + for field in fields(dataclass_type): + if field.name not in combined_fields: + combined_fields[field.name] = field + return combined_fields + def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]: descriptions = [] @@ -297,7 +407,7 @@ class _SimpleArgParser: self.params[prev_arg] = None def _get_param(self, key): - return self.params.get(key.replace("_", "-"), None) + return self.params.get(key.replace("_", "-")) or self.params.get(key) def __getattr__(self, item): return self._get_param(item) diff --git a/pilot/utils/utils.py b/pilot/utils/utils.py index 2e205c275..f3cac70d2 100644 --- a/pilot/utils/utils.py +++ b/pilot/utils/utils.py @@ -161,3 +161,15 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop: raise e logging.warning("Cant not get running event loop, create new event loop now") return asyncio.get_event_loop_policy().get_event_loop() + + +def logging_str_to_uvicorn_level(log_level_str): + level_str_mapping = { + "CRITICAL": "critical", + "ERROR": "error", + "WARNING": "warning", + "INFO": "info", + "DEBUG": "debug", + "NOTSET": "info", + } + return level_str_mapping.get(log_level_str.upper(), "info")