#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from dbgpt.util.parameter_utils import BaseParameters, BaseServerParameters


class WorkerType(str, Enum):
    LLM = "llm"
    TEXT2VEC = "text2vec"

    @staticmethod
    def values():
        return [item.value for item in WorkerType]

    @staticmethod
    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
        """Generate a worker key from the worker name and worker type

        Args:
            worker_name (str): Worker name (e.g., chatglm2-6b)
            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm', or [`WorkerType.LLM`])

        Returns:
            str: Generated worker key
        """
        if "@" in worker_name:
            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
        if isinstance(worker_type, WorkerType):
            worker_type = worker_type.value
        return f"{worker_name}@{worker_type}"

    @staticmethod
    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
        """Parse the worker name and worker type from a worker key

        Args:
            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]

        Returns:
            Tuple[str, str]: Worker name and worker type
        """
        return tuple(worker_key.split("@"))


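# A minimal usage sketch of the worker key helpers defined above (illustrative
# only, not part of the module's runtime logic):
#
#     key = WorkerType.to_worker_key("chatglm2-6b", WorkerType.LLM)
#     # key == "chatglm2-6b@llm"
#     name, worker_type = WorkerType.parse_worker_key(key)
#     # (name, worker_type) == ("chatglm2-6b", "llm")

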
@dataclass
class ModelControllerParameters(BaseServerParameters):
    port: Optional[int] = field(
        default=8000, metadata={"help": "Model Controller deploy port"}
    )
    registry_type: Optional[str] = field(
        default="embedded",
        metadata={
            "help": "Registry type: embedded, database...",
            "valid_values": ["embedded", "database"],
        },
    )
    registry_db_type: Optional[str] = field(
        default="mysql",
        metadata={
            "help": "Registry database type, currently only sqlite and mysql are "
            "supported, only valid when registry_type is database",
            "valid_values": ["mysql", "sqlite"],
        },
    )
    registry_db_name: Optional[str] = field(
        default="dbgpt",
        metadata={
            "help": "Registry database name, only valid when registry_type is "
            "database, please set it to the full database path for sqlite"
        },
    )
    registry_db_host: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database host, only valid when registry_type is "
            "database"
        },
    )
    registry_db_port: Optional[int] = field(
        default=None,
        metadata={
            "help": "Registry database port, only valid when registry_type is "
            "database"
        },
    )
    registry_db_user: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database user, only valid when registry_type is "
            "database"
        },
    )
    registry_db_password: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database password, only valid when registry_type is "
            "database. We recommend storing the password in an environment "
            "variable, e.g. export CONTROLLER_REGISTRY_DB_PASSWORD='your_password'"
        },
    )
    registry_db_pool_size: Optional[int] = field(
        default=5,
        metadata={
            "help": "Registry database pool size, only valid when registry_type is "
            "database"
        },
    )
    registry_db_max_overflow: Optional[int] = field(
        default=10,
        metadata={
            "help": "Registry database max overflow, only valid when registry_type "
            "is database"
        },
    )

    heartbeat_interval_secs: Optional[int] = field(
        default=20, metadata={"help": "The interval for checking heartbeats (seconds)"}
    )
    heartbeat_timeout_secs: Optional[int] = field(
        default=60,
        metadata={
            "help": "The timeout for checking heartbeats (seconds), the worker will "
            "be marked unhealthy if it does not respond within this time"
        },
    )

    log_file: Optional[str] = field(
        default="dbgpt_model_controller.log",
        metadata={
            "help": "The filename to store log",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_controller_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class to store tracer span records",
        },
    )


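# A hedged configuration sketch for the controller parameters above. It constructs
# the dataclass directly and assumes BaseServerParameters adds no extra required
# fields; in practice these values usually come from CLI flags or environment
# variables handled by BaseServerParameters:
#
#     controller_params = ModelControllerParameters(
#         port=8000,
#         registry_type="database",
#         registry_db_type="mysql",
#         registry_db_name="dbgpt",
#         registry_db_host="127.0.0.1",
#         registry_db_port=3306,
#         registry_db_user="root",
#         registry_db_password=os.environ.get("CONTROLLER_REGISTRY_DB_PASSWORD"),
#     )

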
@dataclass
class ModelAPIServerParameters(BaseServerParameters):
    port: Optional[int] = field(
        default=8100, metadata={"help": "Model API server deploy port"}
    )
    controller_addr: Optional[str] = field(
        default="http://127.0.0.1:8000",
        metadata={"help": "The Model controller address to connect"},
    )

    api_keys: Optional[str] = field(
        default=None,
        metadata={"help": "Optional list of comma separated API keys"},
    )
    embedding_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Embedding batch size"}
    )

    log_file: Optional[str] = field(
        default="dbgpt_model_apiserver.log",
        metadata={
            "help": "The filename to store log",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_apiserver_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class to store tracer span records",
        },
    )


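# A minimal sketch of an API server configuration (illustrative values only):
#
#     api_params = ModelAPIServerParameters(
#         port=8100,
#         controller_addr="http://127.0.0.1:8000",
#         api_keys="sk-key-1,sk-key-2",  # comma separated, both keys accepted
#     )

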
@dataclass
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})


@dataclass
class ModelWorkerParameters(BaseServerParameters, BaseModelParameters):
    worker_type: Optional[str] = field(
        default=None,
        metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
    )
    model_alias: Optional[str] = field(
        default=None,
        metadata={"help": "Model alias"},
    )
    worker_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "Model worker class, e.g. dbgpt.model.cluster.DefaultModelWorker"
        },
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )

    port: Optional[int] = field(
        default=8001, metadata={"help": "Model worker deploy port"}
    )
    limit_model_concurrency: Optional[int] = field(
        default=5, metadata={"help": "Model concurrency limit"}
    )
    standalone: Optional[bool] = field(
        default=False,
        metadata={"help": "Standalone mode. If True, run an embedded ModelController"},
    )
    register: Optional[bool] = field(
        default=True, metadata={"help": "Register current worker to model controller"}
    )
    worker_register_host: Optional[str] = field(
        default=None,
        metadata={
            "help": "The ip address of current worker to register to ModelController. "
            "If None, the address is automatically determined"
        },
    )
    controller_addr: Optional[str] = field(
        default=None, metadata={"help": "The Model controller address to register"}
    )
    send_heartbeat: Optional[bool] = field(
        default=True, metadata={"help": "Send heartbeat to model controller"}
    )
    heartbeat_interval: Optional[int] = field(
        default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
    )

    log_file: Optional[str] = field(
        default="dbgpt_model_worker_manager.log",
        metadata={
            "help": "The filename to store log",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_worker_manager_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class to store tracer span records",
        },
    )


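# A hedged sketch of a worker configuration that registers itself with a running
# controller (values are illustrative; model_name and model_path are the required
# fields inherited from BaseModelParameters):
#
#     worker_params = ModelWorkerParameters(
#         model_name="chatglm2-6b",
#         model_path="/data/models/chatglm2-6b",
#         worker_type=WorkerType.LLM.value,
#         port=8001,
#         register=True,
#         controller_addr="http://127.0.0.1:8000",
#         send_heartbeat=True,
#         heartbeat_interval=20,
#     )

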
@dataclass
class BaseEmbeddingModelParameters(BaseModelParameters):
    def build_kwargs(self, **kwargs) -> Dict:
        """Build the keyword arguments used to load the embedding model"""
        pass

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return False


@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically determined"
        },
    )

    normalize_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Determines whether the model's embeddings should be normalized."
        },
    )

    rerank: Optional[bool] = field(
        default=False, metadata={"help": "Whether the model is a rerank model"}
    )

    max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Max length for input sequences. Longer sequences will be "
            "truncated. If None, the max length of the model will be used, "
            "currently only used for rerank models."
        },
    )

    def build_kwargs(self, **kwargs) -> Dict:
        model_kwargs, encode_kwargs = None, None
        if self.device:
            model_kwargs = {"device": self.device}
        if self.normalize_embeddings:
            encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
        if model_kwargs:
            kwargs["model_kwargs"] = model_kwargs
        if self.is_rerank_model():
            kwargs["max_length"] = self.max_length
        if encode_kwargs:
            kwargs["encode_kwargs"] = encode_kwargs
        return kwargs

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return self.rerank


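# A minimal sketch of how build_kwargs combines the fields above (illustrative
# values only):
#
#     p = EmbeddingModelParameters(
#         model_name="text2vec",
#         model_path="/data/models/text2vec-large-chinese",
#         device="cuda:0",
#         normalize_embeddings=True,
#     )
#     p.build_kwargs(model_name=p.model_name)
#     # -> {"model_name": "text2vec",
#     #     "model_kwargs": {"device": "cuda:0"},
#     #     "encode_kwargs": {"normalize_embeddings": True}}

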
@dataclass
class ModelParameters(BaseModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically determined"
        },
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Prompt template. If None, the prompt template is automatically "
            "determined from the model path"
        },
    )
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )

    num_gpus: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of gpus you expect to use, if it is None, all "
            "available gpus will be used"
        },
    )
    max_gpu_memory: Optional[str] = field(
        default=None,
        metadata={
            "help": "The maximum memory limit of each GPU, only valid in multi-GPU configuration"
        },
    )
    cpu_offloading: Optional[bool] = field(
        default=False, metadata={"help": "CPU offloading"}
    )
    load_8bit: Optional[bool] = field(
        default=False, metadata={"help": "8-bit quantization"}
    )
    load_4bit: Optional[bool] = field(
        default=False, metadata={"help": "4-bit quantization"}
    )
    quant_type: Optional[str] = field(
        default="nf4",
        metadata={
            "valid_values": ["nf4", "fp4"],
            "help": "Quantization datatypes, `fp4` (four bit float) and `nf4` "
            "(normal four bit float), only valid when load_4bit=True",
        },
    )
    use_double_quant: Optional[bool] = field(
        default=True,
        metadata={"help": "Nested quantization, only valid when load_4bit=True"},
    )
    compute_dtype: Optional[str] = field(
        default=None,
        metadata={
            "valid_values": ["bfloat16", "float16", "float32"],
            "help": "Model compute type",
        },
    )
    trust_remote_code: Optional[bool] = field(
        default=True, metadata={"help": "Trust remote code"}
    )
    verbose: Optional[bool] = field(
        default=False, metadata={"help": "Show verbose output."}
    )


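# A hedged sketch of a 4-bit quantized Huggingface deployment using the fields
# above (model name and path are illustrative only):
#
#     model_params = ModelParameters(
#         model_name="vicuna-13b-v1.5",
#         model_path="/data/models/vicuna-13b-v1.5",
#         device="cuda",
#         load_4bit=True,
#         quant_type="nf4",        # only read when load_4bit=True
#         use_double_quant=True,
#         compute_dtype="bfloat16",
#     )

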
@dataclass
class LlamaCppModelParameters(ModelParameters):
    seed: Optional[int] = field(
        default=-1, metadata={"help": "Random seed for llama-cpp models. -1 for random"}
    )
    n_threads: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of threads to use. If None, the number of threads is automatically determined"
        },
    )
    n_batch: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum number of prompt tokens to batch together when calling llama_eval"
        },
    )
    n_gpu_layers: Optional[int] = field(
        default=1000000000,
        metadata={
            "help": "Number of layers to offload to the GPU. Set this to 1000000000 to offload all layers to the GPU."
        },
    )
    n_gqa: Optional[int] = field(
        default=None,
        metadata={"help": "Grouped-query attention. Must be 8 for llama-2 70b."},
    )
    rms_norm_eps: Optional[float] = field(
        default=5e-06, metadata={"help": "5e-6 is a good value for llama-2 models."}
    )
    cache_capacity: Optional[str] = field(
        default=None,
        metadata={
            "help": "Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed."
        },
    )
    prefer_cpu: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If a GPU is available, it will be preferred by default, unless prefer_cpu=True is configured."
        },
    )


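# A minimal llama.cpp configuration sketch (illustrative values; the generic
# fields are inherited from ModelParameters):
#
#     llama_params = LlamaCppModelParameters(
#         model_name="llama-2-7b-chat",
#         model_path="/data/models/llama-2-7b-chat.Q4_K_M.gguf",
#         n_gpu_layers=1000000000,  # offload every layer to the GPU
#         n_batch=512,
#         prefer_cpu=False,
#     )

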
@dataclass
class ProxyModelParameters(BaseModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy server url, such as: https://api.openai.com/v1/chat/completions"
        },
    )

    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of current proxy LLM"},
    )

    proxy_api_base: Optional[str] = field(
        default=None,
        metadata={
            "help": "The base api address, such as: https://api.openai.com/v1. If "
            "proxy_server_url is not set, it will be derived from this address"
        },
    )

    proxy_api_app_id: Optional[str] = field(
        default=None,
        metadata={
            "help": "The app id for current proxy LLM (just for spark proxy LLM now)."
        },
    )

    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "help": "The app secret for current proxy LLM (just for spark proxy LLM now)."
        },
    )

    proxy_api_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api type of the current proxy model, if you use Azure, it can be: azure"
        },
    )

    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={"help": "The api version of the current proxy model"},
    )

    http_proxy: Optional[str] = field(
        default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
        metadata={"help": "The http or https proxy used to call the proxy server"},
    )

    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model name actually passed to the current proxy server url, "
            "such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
        },
    )
    model_type: Optional[str] = field(
        default="proxy",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically "
            "determined"
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Prompt template. If None, the prompt template is automatically "
            "determined from the model path"
        },
    )
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )
    llm_client_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "The class name of the llm client, such as "
            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
        },
    )

    def __post_init__(self):
        if not self.proxy_server_url and self.proxy_api_base:
            self.proxy_server_url = f"{self.proxy_api_base}/chat/completions"


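# A hedged sketch showing how __post_init__ derives proxy_server_url from
# proxy_api_base when the former is left empty (illustrative values; the api key
# lookup is a placeholder):
#
#     p = ProxyModelParameters(
#         model_name="chatgpt_proxyllm",
#         model_path="chatgpt_proxyllm",
#         proxy_server_url="",  # left empty on purpose
#         proxy_api_key=os.environ.get("OPENAI_API_KEY", ""),
#         proxy_api_base="https://api.openai.com/v1",
#         proxyllm_backend="gpt-3.5-turbo",
#     )
#     # p.proxy_server_url == "https://api.openai.com/v1/chat/completions"

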
@dataclass
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy base url(OPENAI_API_BASE), such as https://api.openai.com/v1"
        },
    )
    proxy_api_key: str = field(
        metadata={
            "tags": "privacy",
            "help": "The api key of the current embedding model(OPENAI_API_KEY)",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={"help": "Device to run model. Not working for proxy embedding model"},
    )
    proxy_api_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api type of the current embedding model(OPENAI_API_TYPE), "
            "if you use Azure, it can be: azure"
        },
    )
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "tags": "privacy",
            "help": "The api secret of the current embedding model(OPENAI_API_SECRET)",
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api version of the current embedding model(OPENAI_API_VERSION)"
        },
    )
    proxy_backend: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={
            "help": "The model name actually passed to the current proxy server url, "
            "such as text-embedding-ada-002"
        },
    )

    proxy_deployment: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
    )

    rerank: Optional[bool] = field(
        default=False, metadata={"help": "Whether the model is a rerank model"}
    )

    def build_kwargs(self, **kwargs) -> Dict:
        params = {
            "openai_api_base": self.proxy_server_url,
            "openai_api_key": self.proxy_api_key,
            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
            "openai_api_version": (
                self.proxy_api_version if self.proxy_api_version else None
            ),
            "model": self.proxy_backend,
            "deployment": (
                self.proxy_deployment if self.proxy_deployment else self.proxy_backend
            ),
        }
        for k, v in kwargs.items():
            params[k] = v
        return params

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return self.rerank


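# A minimal sketch of the OpenAI-style kwargs produced by build_kwargs above
# (illustrative values; the api key lookup is a placeholder):
#
#     p = ProxyEmbeddingParameters(
#         model_name="proxy_openai",
#         model_path="text-embedding-ada-002",
#         proxy_server_url="https://api.openai.com/v1",
#         proxy_api_key=os.environ.get("OPENAI_API_KEY", ""),
#     )
#     kwargs = p.build_kwargs()
#     # kwargs["openai_api_base"] == "https://api.openai.com/v1"
#     # kwargs["model"] == "text-embedding-ada-002"
#     # kwargs["deployment"] uses proxy_deployment, falling back to proxy_backend if unset

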
_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}


def _update_embedding_config():
    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
        models = [m.strip() for m in models.split(",")]
        for model in models:
            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls


_update_embedding_config()
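
# After _update_embedding_config() runs at import time, the reverse lookup table
# is populated, e.g. (illustrative):
#
#     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG["proxy_openai"] is ProxyEmbeddingParameters
#     EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG["proxy_azure"] is ProxyEmbeddingParameters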