#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from dbgpt.util.parameter_utils import BaseParameters, BaseServerParameters


class WorkerType(str, Enum):
    LLM = "llm"
    TEXT2VEC = "text2vec"

    @staticmethod
    def values():
        return [item.value for item in WorkerType]

    @staticmethod
    def to_worker_key(worker_name, worker_type: Union[str, "WorkerType"]) -> str:
        """Generate worker key from worker name and worker type

        Args:
            worker_name (str): Worker name (e.g., chatglm2-6b)
            worker_type (Union[str, "WorkerType"]): Worker type (e.g., 'llm' or
                [`WorkerType.LLM`])

        Returns:
            str: Generated worker key
        """
        if "@" in worker_name:
            raise ValueError(f"Invalid symbol '@' in your worker name {worker_name}")
        if isinstance(worker_type, WorkerType):
            worker_type = worker_type.value
        return f"{worker_name}@{worker_type}"

    @staticmethod
    def parse_worker_key(worker_key: str) -> Tuple[str, str]:
        """Parse worker name and worker type from worker key

        Args:
            worker_key (str): Worker key generated by [`WorkerType.to_worker_key`]

        Returns:
            Tuple[str, str]: Worker name and worker type
        """
        return tuple(worker_key.split("@"))
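
# Illustrative round trip (a sketch based on the two helpers above): the
# worker key joins name and type with "@", and parsing is the inverse.
#
#     >>> WorkerType.to_worker_key("chatglm2-6b", WorkerType.LLM)
#     'chatglm2-6b@llm'
#     >>> WorkerType.parse_worker_key("chatglm2-6b@llm")
#     ('chatglm2-6b', 'llm')
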
@dataclass
class ModelControllerParameters(BaseServerParameters):
    port: Optional[int] = field(
        default=8000, metadata={"help": "Model Controller deploy port"}
    )
    registry_type: Optional[str] = field(
        default="embedded",
        metadata={
            "help": "Registry type: embedded, database...",
            "valid_values": ["embedded", "database"],
        },
    )
    registry_db_type: Optional[str] = field(
        default="mysql",
        metadata={
            "help": "Registry database type; currently only sqlite and mysql are "
            "supported. Valid only when registry_type is database",
            "valid_values": ["mysql", "sqlite"],
        },
    )
    registry_db_name: Optional[str] = field(
        default="dbgpt",
        metadata={
            "help": "Registry database name. Valid only when registry_type is "
            "database; for sqlite, set it to the full database path"
        },
    )
    registry_db_host: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database host. Valid only when registry_type is "
            "database"
        },
    )
    registry_db_port: Optional[int] = field(
        default=None,
        metadata={
            "help": "Registry database port. Valid only when registry_type is "
            "database"
        },
    )
    registry_db_user: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database user. Valid only when registry_type is "
            "database"
        },
    )
    registry_db_password: Optional[str] = field(
        default=None,
        metadata={
            "help": "Registry database password. Valid only when registry_type is "
            "database. We recommend storing the password in an environment "
            "variable, e.g. "
            "export CONTROLLER_REGISTRY_DB_PASSWORD='your_password'"
        },
    )
    registry_db_pool_size: Optional[int] = field(
        default=5,
        metadata={
            "help": "Registry database connection pool size. Valid only when "
            "registry_type is database"
        },
    )
    registry_db_max_overflow: Optional[int] = field(
        default=10,
        metadata={
            "help": "Registry database connection pool max overflow. Valid only "
            "when registry_type is database"
        },
    )
    heartbeat_interval_secs: Optional[int] = field(
        default=20, metadata={"help": "The interval for checking heartbeats (seconds)"}
    )
    heartbeat_timeout_secs: Optional[int] = field(
        default=60,
        metadata={
            "help": "The timeout for checking heartbeats (seconds); a worker is "
            "marked unhealthy if it does not respond within this time"
        },
    )
    log_file: Optional[str] = field(
        default="dbgpt_model_controller.log",
        metadata={
            "help": "The filename to store logs",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_controller_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class used to store tracer span records",
        },
    )
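
# Illustrative sketch of a database-backed registry configuration. Host, port,
# and user values are placeholders; the password is read from the environment
# variable recommended above.
#
#     params = ModelControllerParameters(
#         registry_type="database",
#         registry_db_type="mysql",
#         registry_db_name="dbgpt",
#         registry_db_host="127.0.0.1",
#         registry_db_port=3306,
#         registry_db_user="dbgpt",
#         registry_db_password=os.environ.get("CONTROLLER_REGISTRY_DB_PASSWORD"),
#     )
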
@dataclass
class ModelAPIServerParameters(BaseServerParameters):
    port: Optional[int] = field(
        default=8100, metadata={"help": "Model API server deploy port"}
    )
    controller_addr: Optional[str] = field(
        default="http://127.0.0.1:8000",
        metadata={"help": "The Model controller address to connect to"},
    )
    api_keys: Optional[str] = field(
        default=None,
        metadata={"help": "Optional list of comma-separated API keys"},
    )
    embedding_batch_size: Optional[int] = field(
        default=None, metadata={"help": "Embedding batch size"}
    )
    log_file: Optional[str] = field(
        default="dbgpt_model_apiserver.log",
        metadata={
            "help": "The filename to store logs",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_apiserver_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class used to store tracer span records",
        },
    )


@dataclass
class BaseModelParameters(BaseParameters):
    model_name: str = field(metadata={"help": "Model name", "tags": "fixed"})
    model_path: str = field(metadata={"help": "Model path", "tags": "fixed"})


@dataclass
class ModelWorkerParameters(BaseServerParameters, BaseModelParameters):
    worker_type: Optional[str] = field(
        default=None,
        metadata={"valid_values": WorkerType.values(), "help": "Worker type"},
    )
    model_alias: Optional[str] = field(
        default=None,
        metadata={"help": "Model alias"},
    )
    worker_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "Model worker class, e.g. dbgpt.model.cluster.DefaultModelWorker"
        },
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    port: Optional[int] = field(
        default=8001, metadata={"help": "Model worker deploy port"}
    )
    limit_model_concurrency: Optional[int] = field(
        default=5, metadata={"help": "Model concurrency limit"}
    )
    standalone: Optional[bool] = field(
        default=False,
        metadata={"help": "Standalone mode. If True, run an embedded ModelController"},
    )
    register: Optional[bool] = field(
        default=True, metadata={"help": "Register current worker to model controller"}
    )
    worker_register_host: Optional[str] = field(
        default=None,
        metadata={
            "help": "The IP address of the current worker to register to the "
            "ModelController. If None, the address is automatically determined"
        },
    )
    controller_addr: Optional[str] = field(
        default=None, metadata={"help": "The Model controller address to register to"}
    )
    send_heartbeat: Optional[bool] = field(
        default=True, metadata={"help": "Send heartbeat to model controller"}
    )
    heartbeat_interval: Optional[int] = field(
        default=20, metadata={"help": "The interval for sending heartbeats (seconds)"}
    )
    log_file: Optional[str] = field(
        default="dbgpt_model_worker_manager.log",
        metadata={
            "help": "The filename to store logs",
        },
    )
    tracer_file: Optional[str] = field(
        default="dbgpt_model_worker_manager_tracer.jsonl",
        metadata={
            "help": "The filename to store tracer span records",
        },
    )
    tracer_storage_cls: Optional[str] = field(
        default=None,
        metadata={
            "help": "The storage class used to store tracer span records",
        },
    )
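
# Illustrative sketch (not a full configuration): a worker serving an LLM and
# registering with a controller. The model name and path are placeholders, and
# inherited BaseServerParameters fields may also apply.
#
#     worker_params = ModelWorkerParameters(
#         model_name="chatglm2-6b",
#         model_path="/path/to/chatglm2-6b",
#         worker_type=WorkerType.LLM.value,
#         port=8001,
#         controller_addr="http://127.0.0.1:8000",
#     )
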
@dataclass
class BaseEmbeddingModelParameters(BaseModelParameters):
    def build_kwargs(self, **kwargs) -> Dict:
        """Build the keyword arguments used to construct the embedding model."""
        raise NotImplementedError

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return False


@dataclass
class EmbeddingModelParameters(BaseEmbeddingModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically "
            "determined"
        },
    )
    normalize_embeddings: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Determines whether the model's embeddings should be normalized"
        },
    )
    rerank: Optional[bool] = field(
        default=False, metadata={"help": "Whether the model is a rerank model"}
    )
    max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": "Max length for input sequences. Longer sequences will be "
            "truncated. If None, the max length of the model will be used; only "
            "for rerank models now"
        },
    )

    def build_kwargs(self, **kwargs) -> Dict:
        model_kwargs, encode_kwargs = None, None
        if self.device:
            model_kwargs = {"device": self.device}
        if self.normalize_embeddings:
            encode_kwargs = {"normalize_embeddings": self.normalize_embeddings}
        if model_kwargs:
            kwargs["model_kwargs"] = model_kwargs
        if self.is_rerank_model():
            kwargs["max_length"] = self.max_length
        if encode_kwargs:
            kwargs["encode_kwargs"] = encode_kwargs
        return kwargs

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return self.rerank


@dataclass
class ModelParameters(BaseModelParameters):
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically "
            "determined"
        },
    )
    model_type: Optional[str] = field(
        default="huggingface",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Prompt template. If None, the prompt template is automatically "
            "determined from the model path"
        },
    )
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )
    num_gpus: Optional[int] = field(
        default=None,
        metadata={
            "help": "The number of GPUs you expect to use; if empty, use as many as "
            "possible"
        },
    )
    max_gpu_memory: Optional[str] = field(
        default=None,
        metadata={
            "help": "The maximum memory limit of each GPU, only valid in multi-GPU "
            "configurations"
        },
    )
    cpu_offloading: Optional[bool] = field(
        default=False, metadata={"help": "CPU offloading"}
    )
    load_8bit: Optional[bool] = field(
        default=False, metadata={"help": "8-bit quantization"}
    )
    load_4bit: Optional[bool] = field(
        default=False, metadata={"help": "4-bit quantization"}
    )
    quant_type: Optional[str] = field(
        default="nf4",
        metadata={
            "valid_values": ["nf4", "fp4"],
            "help": "Quantization datatypes, `fp4` (four bit float) and `nf4` "
            "(normal four bit float), only valid when load_4bit=True",
        },
    )
    use_double_quant: Optional[bool] = field(
        default=True,
        metadata={"help": "Nested quantization, only valid when load_4bit=True"},
    )
    compute_dtype: Optional[str] = field(
        default=None,
        metadata={
            "valid_values": ["bfloat16", "float16", "float32"],
            "help": "Model compute type",
        },
    )
    trust_remote_code: Optional[bool] = field(
        default=True, metadata={"help": "Trust remote code"}
    )
    verbose: Optional[bool] = field(
        default=False, metadata={"help": "Show verbose output."}
    )
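
# Illustrative build_kwargs() output (model name and path are placeholders):
# the device goes into model_kwargs and normalization into encode_kwargs.
#
#     >>> params = EmbeddingModelParameters(
#     ...     model_name="text2vec", model_path="/path/to/text2vec",
#     ...     device="cuda", normalize_embeddings=True,
#     ... )
#     >>> params.build_kwargs()
#     {'model_kwargs': {'device': 'cuda'}, 'encode_kwargs': {'normalize_embeddings': True}}
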
@dataclass
class LlamaCppModelParameters(ModelParameters):
    seed: Optional[int] = field(
        default=-1, metadata={"help": "Random seed for llama-cpp models. -1 for random"}
    )
    n_threads: Optional[int] = field(
        default=None,
        metadata={
            "help": "Number of threads to use. If None, the number of threads is "
            "automatically determined"
        },
    )
    n_batch: Optional[int] = field(
        default=512,
        metadata={
            "help": "Maximum number of prompt tokens to batch together when calling "
            "llama_eval"
        },
    )
    n_gpu_layers: Optional[int] = field(
        default=1000000000,
        metadata={
            "help": "Number of layers to offload to the GPU. Set this to 1000000000 "
            "to offload all layers to the GPU"
        },
    )
    n_gqa: Optional[int] = field(
        default=None,
        metadata={"help": "Grouped-query attention. Must be 8 for llama-2 70b."},
    )
    rms_norm_eps: Optional[float] = field(
        default=5e-06, metadata={"help": "5e-6 is a good value for llama-2 models."}
    )
    cache_capacity: Optional[str] = field(
        default=None,
        metadata={
            "help": "Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided "
            "without units, bytes will be assumed"
        },
    )
    prefer_cpu: Optional[bool] = field(
        default=False,
        metadata={
            "help": "If a GPU is available it is preferred by default; set "
            "prefer_cpu=True to force CPU"
        },
    )


@dataclass
class ProxyModelParameters(BaseModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy server url, such as: "
            "https://api.openai.com/v1/chat/completions"
        },
    )
    proxy_api_key: str = field(
        metadata={"tags": "privacy", "help": "The api key of the current proxy LLM"},
    )
    proxy_api_base: Optional[str] = field(
        default=None,
        metadata={
            "help": "The base api address, such as: https://api.openai.com/v1. If "
            "proxy_server_url is not set, it will be derived from this address"
        },
    )
    proxy_api_app_id: Optional[str] = field(
        default=None,
        metadata={
            "help": "The app id for the current proxy LLM (just for Spark proxy "
            "LLM now)"
        },
    )
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "help": "The app secret for the current proxy LLM (just for Spark "
            "proxy LLM now)"
        },
    )
    proxy_api_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api type of the current proxy model; if you use Azure, it "
            "can be: azure"
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={"help": "The api version of the current proxy model"},
    )
    http_proxy: Optional[str] = field(
        default=os.environ.get("http_proxy") or os.environ.get("https_proxy"),
        metadata={"help": "The http or https proxy used to access OpenAI"},
    )
    proxyllm_backend: Optional[str] = field(
        default=None,
        metadata={
            "help": "The model name actually passed to the current proxy server "
            "url, such as gpt-3.5-turbo, gpt-4, chatglm_pro, chatglm_std and so on"
        },
    )
    model_type: Optional[str] = field(
        default="proxy",
        metadata={
            "help": "Model type: huggingface, llama.cpp, proxy and vllm",
            "tags": "fixed",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={
            "help": "Device to run model. If None, the device is automatically "
            "determined"
        },
    )
    prompt_template: Optional[str] = field(
        default=None,
        metadata={
            "help": "Prompt template. If None, the prompt template is automatically "
            "determined from the model path"
        },
    )
    max_context_size: Optional[int] = field(
        default=4096, metadata={"help": "Maximum context size"}
    )
    llm_client_class: Optional[str] = field(
        default=None,
        metadata={
            "help": "The class name of the llm client, such as "
            "dbgpt.model.proxy.llms.proxy_model.ProxyModel"
        },
    )

    def __post_init__(self):
        if not self.proxy_server_url and self.proxy_api_base:
            self.proxy_server_url = f"{self.proxy_api_base}/chat/completions"
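
# Illustrative behavior of __post_init__ (the API key is a placeholder): when
# proxy_server_url is empty and proxy_api_base is set, the chat completions
# endpoint is derived from the base address.
#
#     >>> p = ProxyModelParameters(
#     ...     model_name="proxyllm", model_path="proxyllm",
#     ...     proxy_server_url="", proxy_api_key="sk-xxx",
#     ...     proxy_api_base="https://api.openai.com/v1",
#     ... )
#     >>> p.proxy_server_url
#     'https://api.openai.com/v1/chat/completions'
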
@dataclass
class ProxyEmbeddingParameters(BaseEmbeddingModelParameters):
    proxy_server_url: str = field(
        metadata={
            "help": "Proxy base url (OPENAI_API_BASE), such as "
            "https://api.openai.com/v1"
        },
    )
    proxy_api_key: str = field(
        metadata={
            "tags": "privacy",
            "help": "The api key of the current embedding model (OPENAI_API_KEY)",
        },
    )
    device: Optional[str] = field(
        default=None,
        metadata={"help": "Device to run model. Not used for proxy embedding models"},
    )
    proxy_api_type: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api type of the current embedding model (OPENAI_API_TYPE); "
            "if you use Azure, it can be: azure"
        },
    )
    proxy_api_secret: Optional[str] = field(
        default=None,
        metadata={
            "tags": "privacy",
            "help": "The api secret of the current embedding model "
            "(OPENAI_API_SECRET)",
        },
    )
    proxy_api_version: Optional[str] = field(
        default=None,
        metadata={
            "help": "The api version of the current embedding model "
            "(OPENAI_API_VERSION)"
        },
    )
    proxy_backend: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={
            "help": "The model name actually passed to the current proxy server "
            "url, such as text-embedding-ada-002"
        },
    )
    proxy_deployment: Optional[str] = field(
        default="text-embedding-ada-002",
        metadata={"help": "To support Azure OpenAI Service custom deployment names"},
    )
    rerank: Optional[bool] = field(
        default=False, metadata={"help": "Whether the model is a rerank model"}
    )

    def build_kwargs(self, **kwargs) -> Dict:
        params = {
            "openai_api_base": self.proxy_server_url,
            "openai_api_key": self.proxy_api_key,
            "openai_api_type": self.proxy_api_type if self.proxy_api_type else None,
            "openai_api_version": (
                self.proxy_api_version if self.proxy_api_version else None
            ),
            "model": self.proxy_backend,
            "deployment": (
                self.proxy_deployment if self.proxy_deployment else self.proxy_backend
            ),
        }
        # Merge any extra keyword arguments into the result.
        for k, v in kwargs.items():
            params[k] = v
        return params

    def is_rerank_model(self) -> bool:
        """Check if the model is a rerank model"""
        return self.rerank


_EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG = {
    ProxyEmbeddingParameters: "proxy_openai,proxy_azure,proxy_http_openapi,"
    "proxy_ollama,proxy_tongyi,proxy_qianfan,rerank_proxy_http_openapi",
}

EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG = {}


def _update_embedding_config():
    """Build the name -> parameter class map from the class -> names config."""
    global EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG
    for param_cls, models in _EMBEDDING_PARAMETER_CLASS_TO_NAME_CONFIG.items():
        models = [m.strip() for m in models.split(",")]
        for model in models:
            if model not in EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG:
                EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG[model] = param_cls


_update_embedding_config()
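
# Illustrative lookup (grounded in the map built above): after
# _update_embedding_config() runs at import time, each comma-separated
# deployment name resolves to its parameter class.
#
#     >>> EMBEDDING_NAME_TO_PARAMETER_CLASS_CONFIG["proxy_openai"] is ProxyEmbeddingParameters
#     True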