mirror of https://github.com/csunny/DB-GPT.git (synced 2025-07-24 12:45:45 +00:00)

feat: Multi-model support with proxyllm and add more command-cli

This commit is contained in: commit b6a4fd8a62 (parent b8f09df45e)
@@ -18,7 +18,7 @@ services:
    networks:
      - dbgptnet
  webserver:
    image: eosphorosai/dbgpt-allinone:latest
    image: eosphorosai/dbgpt:latest
    command: python3 pilot/server/dbgpt_server.py
    environment:
      - LOCAL_DB_HOST=db
@@ -46,7 +46,6 @@ services:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['0']
              capabilities: [gpu]
volumes:
  dbgpt-myql-db:
@@ -55,6 +55,7 @@ RUN (if [ "${LOAD_EXAMPLES}" = "true" ]; \
    && sqlite3 /app/pilot/data/default_sqlite.db < /app/docker/examples/sqls/test_case_info_sqlite.sql; \
    fi;)

ENV PYTHONPATH "/app:$PYTHONPATH"
EXPOSE 5000

CMD ["python3", "pilot/server/dbgpt_server.py"]
docker/compose_examples/cluster-docker-compose.yml (Normal file, 68 lines)
@@ -0,0 +1,68 @@
version: '3.10'

services:
  controller:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start controller
    networks:
      - dbgptnet
  worker:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start worker --model_name vicuna-13b-v1.5 --model_path /app/models/vicuna-13b-v1.5 --port 8001 --controller_addr http://controller:8000
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
    depends_on:
      - controller
    volumes:
      - /data:/data
      # Please modify it to your own model directory
      - /data/models:/app/models
    networks:
      - dbgptnet
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              capabilities: [gpu]
  webserver:
    image: eosphorosai/dbgpt:latest
    command: dbgpt start webserver --light
    environment:
      - DBGPT_LOG_LEVEL=DEBUG
      - LOCAL_DB_PATH=data/default_sqlite.db
      - LOCAL_DB_TYPE=sqlite
      - ALLOWLISTED_PLUGINS=db_dashboard
      - LLM_MODEL=vicuna-13b-v1.5
      - MODEL_SERVER=http://controller:8000
    depends_on:
      - controller
      - worker
    volumes:
      - /data:/data
      # Please modify it to your own model directory
      - /data/models:/app/models
      - dbgpt-data:/app/pilot/data
      - dbgpt-message:/app/pilot/message
    # env_file:
    #   - .env.template
    ports:
      - 5000:5000/tcp
    # webserver may be failed, it must wait all sqls in /docker-entrypoint-initdb.d execute finish.
    restart: unless-stopped
    networks:
      - dbgptnet
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              capabilities: [gpu]
volumes:
  dbgpt-myql-db:
  dbgpt-data:
  dbgpt-message:
networks:
  dbgptnet:
    driver: bridge
    name: dbgptnet
@@ -41,6 +41,10 @@ class Config(metaclass=Singleton):
        # This is a proxy server, just for test_py. we will remove this later.
        self.proxy_api_key = os.getenv("PROXY_API_KEY")
        self.bard_proxy_api_key = os.getenv("BARD_PROXY_API_KEY")
        # In order to be compatible with the new and old model parameter design
        if self.bard_proxy_api_key:
            os.environ["bard_proxyllm_proxy_api_key"] = self.bard_proxy_api_key

        self.proxy_server_url = os.getenv("PROXY_SERVER_URL")

        self.elevenlabs_api_key = os.getenv("ELEVENLABS_API_KEY")
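A quick illustration of the compatibility shim above (a hedged sketch, not part of the commit; it assumes a fresh process in which the Config singleton has not been constructed yet, and the key value is a placeholder):

import os
from pilot.configs.config import Config

os.environ["BARD_PROXY_API_KEY"] = "example-bard-key"  # old-style variable
CFG = Config()
# Config mirrors the old variable into the per-model environment key that the
# new proxy model parameter loading can pick up:
print(os.environ.get("bard_proxyllm_proxy_api_key"))  # example-bard-key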
@@ -70,7 +70,7 @@ LLM_MODEL_CONFIG = {
    "claude_proxyllm": "claude_proxyllm",
    "wenxin_proxyllm": "wenxin_proxyllm",
    "tongyi_proxyllm": "tongyi_proxyllm",
    "gpt4_proxyllm": "gpt4_proxyllm",
    "zhipu_proxyllm": "zhipu_proxyllm",
    "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
    "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
    "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
@@ -4,7 +4,7 @@
import os
import re
from pathlib import Path
from typing import List, Tuple
from typing import List, Tuple, Callable, Type
from functools import cache
from transformers import (
    AutoModel,
@@ -12,7 +12,11 @@ from transformers import (
    AutoTokenizer,
    LlamaTokenizer,
)
from pilot.model.parameter import ModelParameters, LlamaCppModelParameters
from pilot.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
    ProxyModelParameters,
)
from pilot.configs.model_config import get_device
from pilot.configs.config import Config
from pilot.logs import logger
@@ -26,6 +30,7 @@ class ModelType:

    HF = "huggingface"
    LLAMA_CPP = "llama.cpp"
    PROXY = "proxy"
    # TODO, support more model type


@@ -43,6 +48,8 @@ class BaseLLMAdaper:
        model_type = model_type if model_type else self.model_type()
        if model_type == ModelType.LLAMA_CPP:
            return LlamaCppModelParameters
        elif model_type == ModelType.PROXY:
            return ProxyModelParameters
        return ModelParameters

    def match(self, model_path: str):
@@ -76,7 +83,7 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
        return adapter

    for adapter in llm_model_adapters:
        if adapter.match(model_path):
        if model_path and adapter.match(model_path):
            logger.info(
                f"Found llm model adapter with model path: {model_path}, {adapter}"
            )
@@ -87,6 +94,20 @@ def get_llm_model_adapter(model_name: str, model_path: str) -> BaseLLMAdaper:
    )


def _dynamic_model_parser() -> Callable[[None], List[Type]]:
    from pilot.utils.parameter_utils import _SimpleArgParser

    pre_args = _SimpleArgParser("model_name", "model_path")
    pre_args.parse()
    model_name = pre_args.get("model_name")
    model_path = pre_args.get("model_path")
    if model_name is None:
        return None
    llm_adapter = get_llm_model_adapter(model_name, model_path)
    param_class = llm_adapter.model_param_class()
    return [param_class]


# TODO support cpu? for practise we support gpt4all or chatglm-6b-int4?


@@ -281,6 +302,9 @@ class GPT4AllAdapter(BaseLLMAdaper):
class ProxyllmAdapter(BaseLLMAdaper):
    """The model adapter for local proxy"""

    def model_type(self) -> str:
        return ModelType.PROXY

    def match(self, model_path: str):
        return "proxyllm" in model_path
@@ -1,8 +1,11 @@
import click
import functools
import logging
import os
from typing import Callable, List, Type

from pilot.model.controller.registry import ModelRegistryClient
from pilot.configs.model_config import LOGDIR
from pilot.model.base import WorkerApplyType
from pilot.model.parameter import (
    ModelControllerParameters,
@@ -11,6 +14,8 @@ from pilot.model.parameter import (
)
from pilot.utils import get_or_create_event_loop
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service


MODEL_CONTROLLER_ADDRESS = "http://127.0.0.1:8000"

@@ -157,46 +162,81 @@ def worker_apply(
    print(res)


def add_stop_server_options(func):
    @click.option(
        "--port",
        type=int,
        default=None,
        required=False,
        help=("The port to stop"),
    )
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


@click.command(name="controller")
@EnvArgumentParser.create_click_option(ModelControllerParameters)
def start_model_controller(**kwargs):
    """Start model controller"""

    from pilot.model.controller.controller import run_model_controller

    run_model_controller()
    if kwargs["daemon"]:
        log_file = os.path.join(LOGDIR, "model_controller_uvicorn.log")
        _run_current_with_daemon("ModelController", log_file)
    else:
        from pilot.model.controller.controller import run_model_controller

        run_model_controller()


@click.command(name="controller")
def stop_model_controller(**kwargs):
@add_stop_server_options
def stop_model_controller(port: int):
    """Start model controller"""
    raise NotImplementedError
    # Command fragments to check against running processes
    _stop_service("controller", "ModelController", port=port)


def _model_dynamic_factory() -> Callable[[None], List[Type]]:
    from pilot.model.adapter import _dynamic_model_parser

    param_class = _dynamic_model_parser()
    fix_class = [ModelWorkerParameters]
    if not param_class:
        param_class = [ModelParameters]
    fix_class += param_class
    return fix_class


@click.command(name="worker")
@EnvArgumentParser.create_click_option(ModelWorkerParameters, ModelParameters)
@EnvArgumentParser.create_click_option(
    ModelWorkerParameters, ModelParameters, _dynamic_factory=_model_dynamic_factory
)
def start_model_worker(**kwargs):
    """Start model worker"""
    from pilot.model.worker.manager import run_worker_manager
    if kwargs["daemon"]:
        port = kwargs["port"]
        model_type = kwargs.get("worker_type") or "llm"
        log_file = os.path.join(LOGDIR, f"model_worker_{model_type}_{port}_uvicorn.log")
        _run_current_with_daemon("ModelWorker", log_file)
    else:
        from pilot.model.worker.manager import run_worker_manager

        run_worker_manager()
    run_worker_manager()


@click.command(name="worker")
def stop_model_worker(**kwargs):
@add_stop_server_options
def stop_model_worker(port: int):
    """Stop model worker"""
    raise NotImplementedError


@click.command(name="webserver")
def start_webserver(**kwargs):
    """Start webserver(dbgpt_server.py)"""
    raise NotImplementedError


@click.command(name="webserver")
def stop_webserver(**kwargs):
    """Stop webserver(dbgpt_server.py)"""
    raise NotImplementedError
    name = "ModelWorker"
    if port:
        name = f"{name}-{port}"
    _stop_service("worker", name, port=port)


@click.command(name="apiserver")
@@ -205,7 +245,7 @@ def start_apiserver(**kwargs):
    raise NotImplementedError


@click.command(name="controller")
@click.command(name="apiserver")
def stop_apiserver(**kwargs):
    """Start apiserver(TODO)"""
    raise NotImplementedError
@@ -9,7 +9,9 @@ from pilot.model.proxy.llms.bard import bard_generate_stream
from pilot.model.proxy.llms.claude import claude_generate_stream
from pilot.model.proxy.llms.wenxin import wenxin_generate_stream
from pilot.model.proxy.llms.tongyi import tongyi_generate_stream
from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream
from pilot.model.proxy.llms.zhipu import zhipu_generate_stream

# from pilot.model.proxy.llms.gpt4 import gpt4_generate_stream

CFG = Config()

@@ -20,9 +22,10 @@ def proxyllm_generate_stream(model, tokenizer, params, device, context_len=2048)
        "chatgpt_proxyllm": chatgpt_generate_stream,
        "bard_proxyllm": bard_generate_stream,
        "claude_proxyllm": claude_generate_stream,
        "gpt4_proxyllm": gpt4_generate_stream,
        # "gpt4_proxyllm": gpt4_generate_stream, move to chatgpt_generate_stream
        "wenxin_proxyllm": wenxin_generate_stream,
        "tongyi_proxyllm": tongyi_generate_stream,
        "zhipu_proxyllm": zhipu_generate_stream,
    }

    default_error_message = f"{CFG.LLM_MODEL} LLM is not supported"
@@ -8,6 +8,7 @@ from pilot.model.adapter import get_llm_model_adapter, BaseLLMAdaper, ModelType
from pilot.model.parameter import (
    ModelParameters,
    LlamaCppModelParameters,
    ProxyModelParameters,
)
from pilot.utils import get_gpu_memory
from pilot.utils.parameter_utils import EnvArgumentParser, _genenv_ignoring_key_case
@@ -114,11 +115,12 @@ class ModelLoader:
        llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
        model_type = llm_adapter.model_type()
        self.prompt_template = model_params.prompt_template
        logger.info(f"model_params:\n{model_params}")
        if model_type == ModelType.HF:
            return huggingface_loader(llm_adapter, model_params)
        elif model_type == ModelType.LLAMA_CPP:
            return llamacpp_loader(llm_adapter, model_params)
        elif model_type == ModelType.PROXY:
            return proxyllm_loader(llm_adapter, model_params)
        else:
            raise Exception(f"Unkown model type {model_type}")

@@ -346,3 +348,10 @@ def llamacpp_loader(llm_adapter: BaseLLMAdaper, model_params: LlamaCppModelParam
    model_path = model_params.model_path
    model, tokenizer = LlamaCppModel.from_pretrained(model_path, model_params)
    return model, tokenizer


def proxyllm_loader(llm_adapter: BaseLLMAdaper, model_params: ProxyModelParameters):
    from pilot.model.proxy.llms.proxy_model import ProxyModel

    model = ProxyModel(model_params)
    return model, model
@@ -27,6 +27,9 @@ class ModelControllerParameters(BaseParameters):
    port: Optional[int] = field(
        default=8000, metadata={"help": "Model Controller deploy port"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run Model Controller in background"}
    )


@dataclass
@@ -50,6 +53,9 @@ class ModelWorkerParameters(BaseParameters):
    port: Optional[int] = field(
        default=8000, metadata={"help": "Model worker deploy port"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run Model Worker in background"}
    )
    limit_model_concurrency: Optional[int] = field(
        default=5, metadata={"help": "Model concurrency limit"}
    )
@@ -109,9 +115,13 @@ class EmbeddingModelParameters(BaseParameters):


@dataclass
class ModelParameters(BaseParameters):
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})


@dataclass
class ModelParameters(BaseModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
@@ -120,7 +130,10 @@ class ModelParameters(BaseParameters):
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={"help": "Model type, huggingface or llama.cpp", "tags": "fixed"},
        metadata={
            "help": "Model type, huggingface, llama.cpp and proxy",
            "tags": "fixed",
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
@@ -221,3 +234,43 @@ class LlamaCppModelParameters(ModelParameters):
            "help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=False is configured."
        },
    )


@dataclass
class ProxyModelParameters(BaseModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
        },
    )
    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
    )
    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model name actually pass to current proxy server url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
        },
    )
    model_type: Optional[str] = field(
        default="proxy",
        metadata={
            "help": "Model type, huggingface, llama.cpp and proxy",
            "tags": "fixed",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically determined"
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": f"Prompt template. If None, the prompt template is automatically determined from model path, supported template: {suported_prompt_templates}"
        },
    )
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )
@@ -1,13 +1,19 @@
import bardapi
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel


def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
def bard_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    proxy_server_url = model_params.proxy_server_url

    history = []
    messages: List[ModelMessage] = params["messages"]
    for message in messages:
@@ -35,18 +41,18 @@ def bard_generate_stream(model, tokenizer, params, device, context_len=2048):
        if msg.get("content"):
            msgs.append(msg["content"])

    if CFG.proxy_server_url is not None:
    if proxy_server_url is not None:
        headers = {"Content-Type": "application/json"}
        payloads = {"input": "\n".join(msgs)}
        response = requests.post(
            CFG.proxy_server_url, headers=headers, json=payloads, stream=False
            proxy_server_url, headers=headers, json=payloads, stream=False
        )
        if response.ok:
            yield response.text
        else:
            yield f"bard proxy url request failed!, response = {str(response)}"
    else:
        response = bardapi.core.Bard(CFG.bard_proxy_api_key).get_answer("\n".join(msgs))
        response = bardapi.core.Bard(proxy_api_key).get_answer("\n".join(msgs))

        if response is not None and response.get("content") is not None:
            yield str(response["content"])
@@ -4,18 +4,27 @@
import json
import requests
from typing import List
from pilot.configs.config import Config
from pilot.scene.base_message import ModelMessage, ModelMessageRoleType

CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel


def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048):
def chatgpt_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    history = []

    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    proxy_server_url = model_params.proxy_server_url
    proxyllm_backend = model_params.proxyllm_backend
    if not proxyllm_backend:
        proxyllm_backend = "gpt-3.5-turbo"

    headers = {
        "Authorization": "Bearer " + CFG.proxy_api_key,
        "Token": CFG.proxy_api_key,
        "Authorization": "Bearer " + proxy_api_key,
        "Token": proxy_api_key,
    }

    messages: List[ModelMessage] = params["messages"]
@@ -42,16 +51,16 @@ def chatgpt_generate_stream(model, tokenizer, params, device, context_len=2048):
    history.append(last_user_input)

    payloads = {
        "model": "gpt-3.5-turbo",  # just for test, remove this later
        "model": proxyllm_backend,  # just for test, remove this later
        "messages": history,
        "temperature": params.get("temperature"),
        "max_tokens": params.get("max_new_tokens"),
        "stream": True,
    }

    res = requests.post(
        CFG.proxy_server_url, headers=headers, json=payloads, stream=True
    )
    res = requests.post(proxy_server_url, headers=headers, json=payloads, stream=True)

    print(f"Send request to {proxy_server_url} with real model {proxyllm_backend}")

    text = ""
    for line in res.iter_lines():
@@ -1,7 +1,7 @@
from pilot.configs.config import Config

CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel


def claude_generate_stream(model, tokenizer, params, device, context_len=2048):
def claude_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    yield "claude LLM was not supported!"
@@ -1,7 +0,0 @@
from pilot.configs.config import Config

CFG = Config()


def gpt4_generate_stream(model, tokenizer, params, device, context_len=2048):
    yield "gpt4 LLM was not supported!"
pilot/model/proxy/llms/proxy_model.py (Normal file, 9 lines)
@@ -0,0 +1,9 @@
from pilot.model.parameter import ProxyModelParameters


class ProxyModel:
    def __init__(self, model_params: ProxyModelParameters) -> None:
        self._model_params = model_params

    def get_params(self) -> ProxyModelParameters:
        return self._model_params
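For illustration, a minimal sketch of how the new ProxyModel wrapper and ProxyModelParameters fit together (field values below are placeholders, not defaults from the repository); the proxy generate-stream functions changed in this commit read their configuration this way instead of from the global Config:

from pilot.model.parameter import ProxyModelParameters
from pilot.model.proxy.llms.proxy_model import ProxyModel

# Placeholder values for illustration only
params = ProxyModelParameters(
    model_name="chatgpt_proxyllm",
    model_path="chatgpt_proxyllm",
    proxy_server_url="https://api.openai.com/v1/chat/completions",
    proxy_api_key="sk-xxx",
    proxyllm_backend="gpt-3.5-turbo",
)
model = ProxyModel(params)
# Each *_generate_stream function pulls its settings from the model instance
model_params = model.get_params()
print(model_params.proxy_server_url, model_params.proxyllm_backend)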
@@ -1,7 +1,7 @@
from pilot.configs.config import Config

CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel


def tongyi_generate_stream(model, tokenizer, params, device, context_len=2048):
def tongyi_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    yield "tongyi LLM was not supported!"
@@ -1,7 +1,7 @@
from pilot.configs.config import Config

CFG = Config()
from pilot.model.proxy.llms.proxy_model import ProxyModel


def wenxin_generate_stream(model, tokenizer, params, device, context_len=2048):
def wenxin_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    yield "wenxin LLM is not supported!"
pilot/model/proxy/llms/zhipu.py (Normal file, 18 lines)
@@ -0,0 +1,18 @@
from pilot.model.proxy.llms.proxy_model import ProxyModel


def zhipu_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
    """Zhipu ai, see: https://open.bigmodel.cn/dev/api#overview"""
    model_params = model.get_params()
    print(f"Model: {model}, model_params: {model_params}")

    proxy_api_key = model_params.proxy_api_key
    proxy_server_url = model_params.proxy_server_url
    proxyllm_backend = model_params.proxyllm_backend

    if not proxyllm_backend:
        proxyllm_backend = "chatglm_pro"
    # TODO
    yield "Zhipu LLM was not supported!"
@@ -33,6 +33,9 @@ class DefaultModelWorker(ModelWorker):
        self.llm_adapter = get_llm_model_adapter(self.model_name, self.model_path)
        model_type = self.llm_adapter.model_type()
        self.param_cls = self.llm_adapter.model_param_class(model_type)
        logger.info(
            f"model_name: {self.model_name}, model_path: {self.model_path}, model_param_class: {self.param_cls}"
        )

        self.llm_chat_adapter = get_llm_chat_adapter(self.model_name, self.model_path)
        self.generate_stream_func = self.llm_chat_adapter.get_generate_stream_func(
@@ -309,7 +309,7 @@ class LocalWorkerManager(WorkerManager):
        else:
            # Apply to all workers
            worker_instances = list(itertools.chain(*self.workers.values()))
            logger.info(f"Apply to all workers: {worker_instances}")
            logger.info(f"Apply to all workers")
        return await asyncio.gather(
            *(apply_func(worker) for worker in worker_instances)
        )
@@ -61,8 +61,6 @@ try:
        stop_model_controller,
        start_model_worker,
        stop_model_worker,
        start_webserver,
        stop_webserver,
        start_apiserver,
        stop_apiserver,
    )
@@ -70,17 +68,24 @@ try:
    add_command_alias(model_cli_group, name="model", parent_group=cli)
    add_command_alias(start_model_controller, name="controller", parent_group=start)
    add_command_alias(start_model_worker, name="worker", parent_group=start)
    add_command_alias(start_webserver, name="webserver", parent_group=start)
    add_command_alias(start_apiserver, name="apiserver", parent_group=start)

    add_command_alias(stop_model_controller, name="controller", parent_group=stop)
    add_command_alias(stop_model_worker, name="worker", parent_group=stop)
    add_command_alias(stop_webserver, name="webserver", parent_group=stop)
    add_command_alias(stop_apiserver, name="apiserver", parent_group=stop)

except ImportError as e:
    logging.warning(f"Integrating dbgpt model command line tool failed: {e}")

try:
    from pilot.server._cli import start_webserver, stop_webserver

    add_command_alias(start_webserver, name="webserver", parent_group=start)
    add_command_alias(stop_webserver, name="webserver", parent_group=stop)

except ImportError as e:
    logging.warning(f"Integrating dbgpt webserver command line tool failed: {e}")

try:
    from pilot.server.knowledge._cli.knowledge_cli import knowledge_cli_group
pilot/server/_cli.py (Normal file, 32 lines)
@@ -0,0 +1,32 @@
import click
import os
from pilot.server.base import WebWerverParameters
from pilot.configs.model_config import LOGDIR
from pilot.utils.parameter_utils import EnvArgumentParser
from pilot.utils.command_utils import _run_current_with_daemon, _stop_service


@click.command(name="webserver")
@EnvArgumentParser.create_click_option(WebWerverParameters)
def start_webserver(**kwargs):
    """Start webserver(dbgpt_server.py)"""
    if kwargs["daemon"]:
        log_file = os.path.join(LOGDIR, "webserver_uvicorn.log")
        _run_current_with_daemon("WebServer", log_file)
    else:
        from pilot.server.dbgpt_server import run_webserver

        run_webserver(WebWerverParameters(**kwargs))


@click.command(name="webserver")
@click.option(
    "--port",
    type=int,
    default=None,
    required=False,
    help=("The port to stop"),
)
def stop_webserver(port: int):
    """Stop webserver(dbgpt_server.py)"""
    _stop_service("webserver", "WebServer", port=port)
@@ -2,13 +2,12 @@ import signal
import os
import threading
import sys
from typing import Optional
from dataclasses import dataclass, field

from pilot.summary.db_summary_client import DBSummaryClient
from pilot.commands.command_mange import CommandRegistry
from pilot.configs.config import Config
from pilot.utils.parameter_utils import BaseParameters

from pilot.common.plugins import scan_plugins
from pilot.connections.manages.connection_manager import ConnectManager

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -20,12 +19,18 @@ def signal_handler(sig, frame):


def async_db_summery():
    from pilot.summary.db_summary_client import DBSummaryClient

    client = DBSummaryClient()
    thread = threading.Thread(target=client.init_db_summary)
    thread.start()


def server_init(args):
    from pilot.commands.command_mange import CommandRegistry
    from pilot.connections.manages.connection_manager import ConnectManager
    from pilot.common.plugins import scan_plugins

    # logger.info(f"args: {args}")

    # init config
@@ -63,3 +68,38 @@ def server_init(args):
    for command in command_disply_commands:
        command_disply_registry.import_commands(command)
    cfg.command_disply = command_disply_registry


@dataclass
class WebWerverParameters(BaseParameters):
    host: Optional[str] = field(
        default="0.0.0.0", metadata={"help": "Webserver deploy host"}
    )
    port: Optional[int] = field(
        default=5000, metadata={"help": "Webserver deploy port"}
    )
    daemon: Optional[bool] = field(
        default=False, metadata={"help": "Run Webserver in background"}
    )
    share: Optional[bool] = field(
        default=False,
        metadata={
            "help": "Whether to create a publicly shareable link for the interface. Creates an SSH tunnel to make your UI accessible from anywhere. "
        },
    )
    log_level: Optional[str] = field(
        default="INFO",
        metadata={
            "help": "Logging level",
            "valid_values": [
                "FATAL",
                "ERROR",
                "WARNING",
                "WARNING",
                "INFO",
                "DEBUG",
                "NOTSET",
            ],
        },
    )
    light: Optional[bool] = field(default=False, metadata={"help": "enable light mode"})
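As a usage sketch (not part of the commit), the new WebWerverParameters dataclass can be passed straight to run_webserver from pilot.server.dbgpt_server, which is what the new `dbgpt start webserver` command does in its non-daemon path; the values below are placeholders:

from pilot.server.base import WebWerverParameters
from pilot.server.dbgpt_server import run_webserver

# Placeholder settings; unspecified fields keep the defaults defined above
run_webserver(WebWerverParameters(port=5000, light=True, log_level="DEBUG"))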
@@ -1,7 +1,7 @@
import os
import argparse
import sys
import logging
from typing import List

ROOT_PATH = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.append(ROOT_PATH)
@@ -9,7 +9,7 @@ import signal
from pilot.configs.config import Config
from pilot.configs.model_config import LLM_MODEL_CONFIG

from pilot.server.base import server_init
from pilot.server.base import server_init, WebWerverParameters

from fastapi.staticfiles import StaticFiles
from fastapi import FastAPI, applications
@@ -24,7 +24,7 @@ from pilot.openapi.base import validation_exception_handler
from pilot.openapi.api_v1.editor.api_editor_v1 import router as api_editor_route_v1
from pilot.commands.disply_type.show_chart_gen import static_message_img_path
from pilot.model.worker.manager import initialize_worker_manager_in_client
from pilot.utils.utils import setup_logging
from pilot.utils.utils import setup_logging, logging_str_to_uvicorn_level

static_file_path = os.path.join(os.getcwd(), "server/static")

@@ -84,33 +84,24 @@ def mount_static_files(app):

app.add_exception_handler(RequestValidationError, validation_exception_handler)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_list_mode", type=str, default="once", choices=["once", "reload"]
    )

    # old version server config
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=5000)
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--share", default=False, action="store_true")
    parser.add_argument("--log-level", type=str, default=None)
    parser.add_argument(
        "-light",
        "--light",
        default=False,
        action="store_true",
        help="enable light mode",
    )
def initialize_app(param: WebWerverParameters = None, args: List[str] = None):
    """Initialize app
    If you use gunicorn as a process manager, initialize_app can be invoke in `on_starting` hook.
    """
    if not param:
        from pilot.utils.parameter_utils import EnvArgumentParser

    # init server config
    args = parser.parse_args()
    setup_logging(logging_level=args.log_level)
    server_init(args)
        parser: argparse.ArgumentParser = EnvArgumentParser.create_argparse_option(
            WebWerverParameters
        )
        param = WebWerverParameters(**vars(parser.parse_args(args=args)))

    setup_logging(logging_level=param.log_level)
    server_init(param)

    model_path = LLM_MODEL_CONFIG[CFG.LLM_MODEL]
    if not args.light:
    if not param.light:
        print("Model Unified Deployment Mode!")
        initialize_worker_manager_in_client(
            app=app, model_name=CFG.LLM_MODEL, model_path=model_path
@@ -129,7 +120,25 @@ if __name__ == "__main__":
        CFG.SERVER_LIGHT_MODE = True

    mount_static_files(app)
    return param


def run_uvicorn(param: WebWerverParameters):
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=args.port, log_level="info")
    uvicorn.run(
        app,
        host=param.host,
        port=param.port,
        log_level=logging_str_to_uvicorn_level(param.log_level),
    )
    signal.signal(signal.SIGINT, signal_handler())


def run_webserver(param: WebWerverParameters = None):
    param = initialize_app(param)
    run_uvicorn(param)


if __name__ == "__main__":
    run_webserver()
pilot/utils/command_utils.py (Normal file, 91 lines)
@@ -0,0 +1,91 @@
import sys
import os
import subprocess
from typing import List, Dict
import psutil
import platform


def _run_current_with_daemon(name: str, log_file: str):
    # Get all arguments except for --daemon
    args = [arg for arg in sys.argv if arg != "--daemon" and arg != "-d"]
    daemon_cmd = [sys.executable] + args
    daemon_cmd = " ".join(daemon_cmd)
    daemon_cmd += f" > {log_file} 2>&1"

    # Check the platform and set the appropriate flags or functions
    if platform.system() == "Windows":
        process = subprocess.Popen(
            daemon_cmd,
            creationflags=subprocess.CREATE_NEW_PROCESS_GROUP,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=True,
        )
    else:  # macOS, Linux, and other Unix-like systems
        process = subprocess.Popen(
            daemon_cmd,
            preexec_fn=os.setsid,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            shell=True,
        )

    print(f"Started {name} in background with pid: {process.pid}")


def _run_current_with_gunicorn(app: str, config_path: str, kwargs: Dict):
    try:
        import gunicorn
    except ImportError as e:
        raise ValueError(
            "Could not import python package: gunicorn"
            "Daemon mode need install gunicorn, please install `pip install gunicorn`"
        ) from e

    from pilot.utils.parameter_utils import EnvArgumentParser

    env_to_app = {}
    env_to_app.update(os.environ)
    app_env = EnvArgumentParser._kwargs_to_env_key_value(kwargs)
    env_to_app.update(app_env)
    cmd = f"uvicorn {app} --host 0.0.0.0 --port 5000"
    if platform.system() == "Windows":
        raise Exception("Not support on windows")
    else:  # macOS, Linux, and other Unix-like systems
        process = subprocess.Popen(cmd, shell=True, env=env_to_app)
        print(f"Started {app} with gunicorn in background with pid: {process.pid}")


def _stop_service(
    key: str, fullname: str, service_keys: List[str] = None, port: int = None
):
    if not service_keys:
        service_keys = [sys.argv[0], "start", key]
    not_found = True
    for process in psutil.process_iter(attrs=["pid", "connections", "cmdline"]):
        try:
            cmdline = " ".join(process.info["cmdline"])

            # Check if all key fragments are in the cmdline
            if all(fragment in cmdline for fragment in service_keys):
                if port:
                    for conn in process.info["connections"]:
                        if (
                            conn.status == psutil.CONN_LISTEN
                            and conn.laddr.port == port
                        ):
                            psutil.Process(process.info["pid"]).terminate()
                            print(
                                f"Terminated the {fullname} with PID: {process.info['pid']} listening on port: {port}"
                            )
                            not_found = False
                else:
                    psutil.Process(process.info["pid"]).terminate()
                    print(f"Terminated the {fullname} with PID: {process.info['pid']}")
                    not_found = False
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

    if not_found:
        print(f"{fullname} process not found.")
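A hedged sketch of calling the stop helper directly (the `dbgpt stop worker --port 8001` command shown earlier does effectively this; the port and process name below are placeholders):

from pilot.utils.command_utils import _stop_service

# service_keys defaults to [sys.argv[0], "start", "worker"], so when run from
# the dbgpt CLI this matches processes started via `dbgpt start worker ...`;
# because port is given, only the one listening on 8001 is terminated.
_stop_service("worker", "ModelWorker-8001", port=8001)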
@@ -1,7 +1,8 @@
import argparse
import os
from dataclasses import dataclass, fields, MISSING
from typing import Any, List, Optional, Type, Union, Callable
from typing import Any, List, Optional, Type, Union, Callable, Dict
from collections import OrderedDict


@dataclass
@@ -60,7 +61,7 @@ class BaseParameters:
            f"\n\n=========================== {class_name} ===========================\n"
        ]
        for field_info in fields(self):
            value = getattr(self, field_info.name)
            value = _get_simple_privacy_field_value(self, field_info)
            parameters.append(f"{field_info.name}: {value}")
        parameters.append(
            "\n======================================================================\n\n"
@@ -68,6 +69,55 @@ class BaseParameters:
        return "\n".join(parameters)


def _get_simple_privacy_field_value(obj, field_info):
    """Retrieve the value of a field from a dataclass instance, applying privacy rules if necessary.

    This function reads the metadata of a field to check if it's tagged with 'privacy'.
    If the 'privacy' tag is present, then it modifies the value based on its type
    for privacy concerns:
    - int: returns -999
    - float: returns -999.0
    - bool: returns False
    - str: if length > 5, masks the middle part and returns first and last char;
      otherwise, returns "******"

    Parameters:
    - obj: The dataclass instance.
    - field_info: A Field object that contains information about the dataclass field.

    Returns:
    The original or modified value of the field based on the privacy rules.

    Example usage:
    @dataclass
    class Person:
        name: str
        age: int
        ssn: str = field(metadata={"tags": "privacy"})
    p = Person("Alice", 30, "123-45-6789")
    print(_get_simple_privacy_field_value(p, Person.ssn))  # A******9
    """
    tags = field_info.metadata.get("tags")
    tags = [] if not tags else tags.split(",")
    is_privacy = False
    if tags and "privacy" in tags:
        is_privacy = True
    value = getattr(obj, field_info.name)
    if not is_privacy or not value:
        return value
    field_type = EnvArgumentParser._get_argparse_type(field_info.type)
    if field_type is int:
        return -999
    if field_type is float:
        return -999.0
    if field_type is bool:
        return False
    # str
    if len(value) > 5:
        return value[0] + "******" + value[-1]
    return "******"


def _genenv_ignoring_key_case(env_key: str, env_prefix: str = None, default_value=None):
    """Get the value from the environment variable, ignoring the case of the key"""
    if env_prefix:
@@ -115,24 +165,11 @@ class EnvArgumentParser:
            if not env_var_value:
                env_var_value = kwargs.get(field.name)

            print(f"env_var_value: {env_var_value} for {field.name}")
            # Add a command-line argument for this field
            help_text = field.metadata.get("help", "")
            valid_values = field.metadata.get("valid_values", None)

            argument_kwargs = {
                "type": EnvArgumentParser._get_argparse_type(field.type),
                "help": help_text,
                "choices": valid_values,
                "required": EnvArgumentParser._is_require_type(field.type),
            }
            if field.default != MISSING:
                argument_kwargs["default"] = field.default
                argument_kwargs["required"] = False
            if env_var_value:
                argument_kwargs["default"] = env_var_value
                argument_kwargs["required"] = False

            parser.add_argument(f"--{field.name}", **argument_kwargs)
            EnvArgumentParser._build_single_argparse_option(
                parser, field, env_var_value
            )

        # Parse the command-line arguments
        cmd_args, cmd_argv = parser.parse_known_args(args=command_args)
@@ -148,7 +185,7 @@ class EnvArgumentParser:
        return dataclass_type(**kwargs)

    @staticmethod
    def create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
    def _create_arg_parser(dataclass_type: Type) -> argparse.ArgumentParser:
        parser = argparse.ArgumentParser(description=dataclass_type.__doc__)
        for field in fields(dataclass_type):
            help_text = field.metadata.get("help", "")
@@ -173,11 +210,6 @@ class EnvArgumentParser:
        import functools
        from collections import OrderedDict

        # TODO dynamic configuration
        # pre_args = _SimpleArgParser('model_name', 'model_path')
        # pre_args.parse()
        # print(pre_args)

        combined_fields = OrderedDict()
        if _dynamic_factory:
            _types = _dynamic_factory()
@@ -225,6 +257,48 @@ class EnvArgumentParser:

        return decorator

    @staticmethod
    def create_argparse_option(
        *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
    ) -> argparse.ArgumentParser:
        combined_fields = _merge_dataclass_types(
            *dataclass_types, _dynamic_factory=_dynamic_factory
        )
        parser = argparse.ArgumentParser()
        for _, field in reversed(combined_fields.items()):
            EnvArgumentParser._build_single_argparse_option(parser, field)
        return parser

    @staticmethod
    def _build_single_argparse_option(
        parser: argparse.ArgumentParser, field, default_value=None
    ):
        # Add a command-line argument for this field
        help_text = field.metadata.get("help", "")
        valid_values = field.metadata.get("valid_values", None)
        short_name = field.metadata.get("short", None)
        argument_kwargs = {
            "type": EnvArgumentParser._get_argparse_type(field.type),
            "help": help_text,
            "choices": valid_values,
            "required": EnvArgumentParser._is_require_type(field.type),
        }
        if field.default != MISSING:
            argument_kwargs["default"] = field.default
            argument_kwargs["required"] = False
        if default_value:
            argument_kwargs["default"] = default_value
            argument_kwargs["required"] = False
        if field.type is bool or field.type == Optional[bool]:
            argument_kwargs["action"] = "store_true"
            del argument_kwargs["type"]
            del argument_kwargs["choices"]
        names = []
        if short_name:
            names.append(f"-{short_name}")
        names.append(f"--{field.name}")
        parser.add_argument(*names, **argument_kwargs)

    @staticmethod
    def _get_argparse_type(field_type: Type) -> Type:
        # Return the appropriate type for argparse to use based on the field type
@@ -255,6 +329,42 @@ class EnvArgumentParser:
    def _is_require_type(field_type: Type) -> str:
        return field_type not in [Optional[int], Optional[float], Optional[bool]]

    @staticmethod
    def _kwargs_to_env_key_value(
        kwargs: Dict, prefix: str = "__dbgpt_gunicorn__env_prefix__"
    ) -> Dict[str, str]:
        return {prefix + k: str(v) for k, v in kwargs.items()}

    @staticmethod
    def _read_env_key_value(
        prefix: str = "__dbgpt_gunicorn__env_prefix__",
    ) -> List[str]:
        env_args = []
        for key, value in os.environ.items():
            if key.startswith(prefix):
                arg_key = "--" + key.replace(prefix, "")
                if value.lower() in ["true", "1"]:
                    # Flag args
                    env_args.append(arg_key)
                elif not value.lower() in ["false", "0"]:
                    env_args.extend([arg_key, value])
        return env_args


def _merge_dataclass_types(
    *dataclass_types: Type, _dynamic_factory: Callable[[None], List[Type]] = None
) -> OrderedDict:
    combined_fields = OrderedDict()
    if _dynamic_factory:
        _types = _dynamic_factory()
        if _types:
            dataclass_types = list(_types)
    for dataclass_type in dataclass_types:
        for field in fields(dataclass_type):
            if field.name not in combined_fields:
                combined_fields[field.name] = field
    return combined_fields


def _get_parameter_descriptions(dataclass_type: Type) -> List[ParameterDescription]:
    descriptions = []
@@ -297,7 +407,7 @@ class _SimpleArgParser:
                self.params[prev_arg] = None

    def _get_param(self, key):
        return self.params.get(key.replace("_", "-"), None)
        return self.params.get(key.replace("_", "-")) or self.params.get(key)

    def __getattr__(self, item):
        return self._get_param(item)
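The new create_argparse_option helper builds an argparse parser straight from the parameter dataclasses; below is a small usage sketch mirroring what initialize_app in dbgpt_server.py does (the argument values are placeholders, not project defaults):

from pilot.server.base import WebWerverParameters
from pilot.utils.parameter_utils import EnvArgumentParser

# Build a parser whose options come from the dataclass fields, then construct
# the dataclass from the parsed namespace (same pattern as initialize_app)
parser = EnvArgumentParser.create_argparse_option(WebWerverParameters)
param = WebWerverParameters(**vars(parser.parse_args(["--port", "5001", "--light"])))
print(param.port, param.light)  # 5001 True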
@@ -161,3 +161,15 @@ def get_or_create_event_loop() -> asyncio.BaseEventLoop:
        raise e
    logging.warning("Cant not get running event loop, create new event loop now")
    return asyncio.get_event_loop_policy().get_event_loop()


def logging_str_to_uvicorn_level(log_level_str):
    level_str_mapping = {
        "CRITICAL": "critical",
        "ERROR": "error",
        "WARNING": "warning",
        "INFO": "info",
        "DEBUG": "debug",
        "NOTSET": "info",
    }
    return level_str_mapping.get(log_level_str.upper(), "info")