mirror of https://github.com/csunny/DB-GPT.git
feat: add sglang support
parent c68332be4c
commit 7b43a039ac

configs/dbgpt-local-sglang.toml (new file, 39 lines)
@@ -0,0 +1,39 @@
[system]
# Load language from environment variable (It is set by the hook)
language = "${env:DBGPT_LANG:-zh}"
api_keys = []
encrypt_key = "your_secret_key"

# Server Configurations
[service.web]
host = "0.0.0.0"
port = 5670

[service.web.database]
type = "sqlite"
path = "pilot/meta_data/dbgpt.db"

[rag.storage]
[rag.storage.vector]
type = "chroma"
persist_path = "pilot/data"

# Model Configurations
[models]
[[models.llms]]
name = "DeepSeek-R1-Distill-Qwen-1.5B"
provider = "sglang"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
path = "models/DeepSeek-R1-Distill-Qwen-1.5B"
# dtype = "float32"

[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
path = "/data/models/bge-large-zh-v1.5"
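
Once the matching dependencies are installed (see the docs change below), this new config is started the same way as the other local configs; the command here simply mirrors the one added later in the documentation:

```bash
# Launch DB-GPT with the new SGLang config added by this commit
uv run dbgpt start webserver --config configs/dbgpt-local-sglang.toml
```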

@@ -97,6 +97,7 @@ This tutorial assumes that you can establish network communication with the dependencies
  {label: 'DeepSeek (proxy)', value: 'deepseek'},
  {label: 'GLM4 (local)', value: 'glm-4'},
  {label: 'VLLM (local)', value: 'vllm'},
  {label: 'SGLang (local)', value: 'sglang'},
  {label: 'LLAMA_CPP (local)', value: 'llama_cpp'},
  {label: 'Ollama (proxy)', value: 'ollama'},
]}>

@@ -291,6 +292,54 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
```

</TabItem>

<TabItem value="sglang" label="SGLang(local)">

```bash
# Use uv to install the dependencies needed for SGLang
# Install core dependencies and select desired extensions
uv sync --all-packages \
  --extra "base" \
  --extra "hf" \
  --extra "cuda121" \
  --extra "sglang" \
  --extra "rag" \
  --extra "storage_chromadb" \
  --extra "quant_bnb" \
  --extra "dbgpts"
```

### Run Webserver

To run DB-GPT with a local model, you can modify the `configs/dbgpt-local-sglang.toml` configuration file to specify the model path and other parameters.

```toml
# Model Configurations
[models]
[[models.llms]]
name = "THUDM/glm-4-9b-chat-hf"
provider = "sglang"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"

[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
```

In the above configuration file, `[[models.llms]]` specifies the LLM model and `[[models.embeddings]]` specifies the embedding model. If you do not provide the `path` parameter, the model will be downloaded from the Hugging Face model hub according to the `name` parameter.
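
For example, to load the LLM from a local checkout instead of downloading it, you would uncomment and set `path`; the directory below is only a hypothetical placeholder:

```toml
[[models.llms]]
name = "THUDM/glm-4-9b-chat-hf"
provider = "sglang"
# Hypothetical local directory; point this at wherever the weights actually live
path = "/data/models/glm-4-9b-chat-hf"
```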

Then run the following command to start the webserver:

```bash
uv run dbgpt start webserver --config configs/dbgpt-local-sglang.toml
```

</TabItem>

<TabItem value="llama_cpp" label="LLAMA_CPP(local)">

If you have an Nvidia GPU, you can enable CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.
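
A minimal sketch of how that variable is typically applied at install time, assuming a `llama_cpp` extra analogous to the `sglang` extra shown above (the extra name is an assumption, not taken from this diff):

```bash
# Hypothetical example: build llama.cpp with CUDA while installing the extras
CMAKE_ARGS="-DGGML_CUDA=ON" uv sync --all-packages \
  --extra "base" \
  --extra "llama_cpp"
```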

@@ -95,6 +95,9 @@ class I18N:
        "vllm_preset": "VLLM Local Mode",
        "vllm_desc": "Using VLLM framework to load local model, requires GPU environment",
        "vllm_note": "Requires local model path configuration",
        "sglang_preset": "SGLang Local Mode",
        "sglang_desc": "Using SGLang framework to load local model, requires GPU environment",
        "sglang_note": "Requires local model path configuration",
        "llama_cpp_preset": "LLAMA_CPP Local Mode",
        "llama_cpp_desc": "Using LLAMA.cpp framework to load local model, can run on CPU but GPU recommended",
        "llama_cpp_note": 'Requires local model path configuration, for CUDA support set CMAKE_ARGS="-DGGML_CUDA=ON"',

@@ -175,6 +178,9 @@ class I18N:
        "vllm_preset": "VLLM 本地模式",
        "vllm_desc": "使用VLLM框架加载本地模型,需要GPU环境",
        "vllm_note": "需要配置本地模型路径",
        "sglang_preset": "SGLang 本地模式",
        "sglang_desc": "使用SGLang框架加载本地模型,需要GPU环境",
        "sglang_note": "需要配置本地模型路径",
        "llama_cpp_preset": "LLAMA_CPP 本地模式",
        "llama_cpp_desc": "使用LLAMA.cpp框架加载本地模型,CPU也可运行但推荐GPU",
        "llama_cpp_note": '需要配置本地模型路径,启用CUDA需设置CMAKE_ARGS="-DGGML_CUDA=ON"',

@@ -361,6 +367,21 @@ def get_deployment_presets():
            "description": i18n.get("vllm_desc"),
            "note": i18n.get("vllm_note"),
        },
        i18n.get("sglang_preset"): {
            "extras": [
                "base",
                "hf",
                "cuda121",
                "sglang",
                "rag",
                "storage_chromadb",
                "quant_bnb",
                "dbgpts",
            ],
            "config": "configs/dbgpt-local-sglang.toml",
            "description": i18n.get("sglang_desc"),
            "note": i18n.get("sglang_note"),
        },
        i18n.get("llama_cpp_preset"): {
            "extras": [
                "base",

@@ -40,6 +40,7 @@ class BaseDeployModelParameters(BaseParameters):
                "llama_cpp_server",
                "proxy/*",
                "vllm",
                "sglang",
            ],
        },
    )

@@ -40,6 +40,7 @@ def scan_model_providers():
            "hf_adapter",
            "llama_cpp_adapter",
            "llama_cpp_py_adapter",
            "sglang_adapter",
        ],
    )
    config_llms = ScannerConfig(

packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py (new file, 373 lines)
@@ -0,0 +1,373 @@
import logging
from dataclasses import dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, Optional, Type

from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _
from dbgpt.util.parameter_utils import (
    _get_dataclass_print_str,
)

logger = logging.getLogger(__name__)


@dataclass
class SGlangDeployModelParameters(LLMDeployModelParameters):
    """SGlang deploy model parameters"""

    provider: str = "sglang"

    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _("The path of the model, if you want to deploy a local model."),
        },
    )

    device: Optional[str] = field(
        default="auto",
        metadata={
            "order": -700,
            "help": _(
                "The device to run the model, 'auto' means using the default device."
            ),
        },
    )

    concurrency: Optional[int] = field(
        default=100, metadata={"help": _("Model concurrency limit")}
    )

    @property
    def real_model_path(self) -> Optional[str]:
        """Get the real model path

        If deploy model is not local, return None
        """
        return self._resolve_root_path(self.path)

    @property
    def real_device(self) -> Optional[str]:
        """Get the real device"""
        return self.device or super().real_device

    def to_sglang_params(
        self, sglang_config_cls: Optional[Type] = None
    ) -> Dict[str, Any]:
        """Convert to sglang deploy model parameters"""
        data = self.to_dict()
        model = data.get("path", None)
        if not model:
            model = data.get("name", None)
        if not model:
            raise ValueError(
                _("Model is required, please specify the model path or name.")
            )

        copy_data = data.copy()
        real_params = {}
        extra_params = copy_data.get("extras", {})
        if sglang_config_cls and is_dataclass(sglang_config_cls):
            for fd in fields(sglang_config_cls):
                if fd.name in copy_data:
                    real_params[fd.name] = copy_data[fd.name]
        else:
            for k, v in copy_data.items():
                if k in [
                    "provider",
                    "path",
                    "name",
                    "extras",
                    "verbose",
                    "backend",
                    "prompt_template",
                    "context_length",
                ]:
                    continue
                real_params[k] = v

        real_params["model"] = model
        if extra_params and isinstance(extra_params, dict):
            real_params.update(extra_params)
        return real_params

    trust_remote_code: Optional[bool] = field(
        default=True, metadata={"help": _("Trust remote code or not.")}
    )

    download_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": _(
                "Directory to download and load the weights, "
                "default to the default cache dir of "
                "huggingface."
            )
        },
    )

    load_format: Optional[str] = field(
        default="auto",
        metadata={
            "help": _(
                "The format of the model weights to load.\n\n"
                '* "auto" will try to load the weights in the safetensors format '
                "and fall back to the pytorch bin format if safetensors format "
                "is not available.\n"
                '* "pt" will load the weights in the pytorch bin format.\n'
                '* "safetensors" will load the weights in the safetensors format.\n'
                '* "npcache" will load the weights in pytorch format and store '
                "a numpy cache to speed up the loading.\n"
                '* "dummy" will initialize the weights with random values, '
                "which is mainly for profiling.\n"
                '* "tensorizer" will load the weights using tensorizer from '
                "CoreWeave. See the Tensorize vLLM Model script in the Examples "
                "section for more information.\n"
                '* "runai_streamer" will load the Safetensors weights using Run:ai '
                "Model Streamer.\n"
                '* "bitsandbytes" will load the weights using bitsandbytes '
                "quantization.\n"
            ),
            "valid_values": [
                "auto",
                "pt",
                "safetensors",
                "npcache",
                "dummy",
                "tensorizer",
                "runai_streamer",
                "bitsandbytes",
                "sharded_state",
                "gguf",
                "mistral",
            ],
        },
    )
    config_format: Optional[str] = field(
        default="auto",
        metadata={
            "help": _(
                "The format of the model config to load.\n\n"
                '* "auto" will try to load the config in hf format '
                "if available else it will try to load in mistral format "
            ),
            "valid_values": [
                "auto",
                "hf",
                "mistral",
            ],
        },
    )

    dtype: Optional[str] = field(
        default="auto",
        metadata={
            "help": _(
                "Data type for model weights and activations.\n\n"
                '* "auto" will use FP16 precision for FP32 and FP16 models, and '
                "BF16 precision for BF16 models.\n"
                '* "half" for FP16.\n'
                '* "float16" is the same as "half".\n'
                '* "bfloat16" for a balance between precision and range.\n'
                '* "float" is shorthand for FP32 precision.\n'
                '* "float32" for FP32 precision.'
            ),
            "valid_values": ["auto", "half", "float16", "bfloat16", "float", "float32"],
        },
    )

    max_model_len: Optional[int] = field(
        default=None,
        metadata={
            "help": _(
                "Model context length. If unspecified, will be automatically derived "
                "from the model config."
            ),
        },
    )
    tensor_parallel_size: int = field(
        default=1,
        metadata={
            "help": _("Number of tensor parallel replicas."),
        },
    )
    max_num_seqs: Optional[int] = field(
        default=None,
        metadata={
            "help": _("Maximum number of sequences per iteration."),
        },
    )
    revision: Optional[str] = field(
        default=None,
        metadata={
            "help": _(
                "The specific model version to use. It can be a branch "
                "name, a tag name, or a commit id. If unspecified, will use "
                "the default version."
            ),
        },
    )
    tokenizer_revision: Optional[str] = field(
        default=None,
        metadata={
            "help": _(
                "Revision of the huggingface tokenizer to use. "
                "It can be a branch name, a tag name, or a commit id. "
                "If unspecified, will use the default version."
            ),
        },
    )
    quantization: Optional[str] = field(
        default=None,
        metadata={
            "help": _(
                "Method used to quantize the weights. If "
                "None, we first check the `quantization_config` "
                "attribute in the model config file. If that is "
                "None, we assume the model weights are not "
                "quantized and use `dtype` to determine the data "
                "type of the weights."
            ),
            "valid_values": ["awq", "gptq", "int8"],
        },
    )
    gpu_memory_utilization: float = field(
        default=0.90,
        metadata={
            "help": _("The fraction of GPU memory to be used for the model."),
        },
    )
    extras: Optional[Dict] = field(
        default=None,
        metadata={
            "help": _("Extra parameters, it will be passed to the sglang engine.")
        },
    )


class SGLangModelAdapterWrapper(LLMModelAdapter):
    """Wrapping sglang engine"""

    def __init__(self, conv_factory: Optional[ConversationAdapterFactory] = None):
        if not conv_factory:
            from dbgpt.model.adapter.model_adapter import (
                DefaultConversationAdapterFactory,
            )

            conv_factory = DefaultConversationAdapterFactory()

        self.conv_factory = conv_factory

    def new_adapter(self, **kwargs) -> "SGLangModelAdapterWrapper":
        new_obj = super().new_adapter(**kwargs)
        new_obj.conv_factory = self.conv_factory
        return new_obj  # type: ignore

    def model_type(self) -> str:
        return ModelType.SGLANG

    def model_param_class(
        self, model_type: str = None
    ) -> Type[SGlangDeployModelParameters]:
        """Get model parameters class."""
        return SGlangDeployModelParameters

    def match(
        self,
        provider: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        return provider == ModelType.SGLANG

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        if not tokenizer:
            raise ValueError("tokenizer is None")

        if hasattr(tokenizer, "apply_chat_template"):
            messages = self.transform_model_messages(
                messages, convert_to_compatible_format
            )
            logger.debug(f"The messages after transform: \n{messages}")
            str_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return str_prompt
        return None

    def load_from_params(self, params: SGlangDeployModelParameters):
        try:
            import sglang as sgl
            from sglang.srt.managers.server import AsyncServerManager
        except ImportError:
            raise ImportError("Please install sglang first: pip install sglang")

        logger.info(
            "Start SGLang AsyncServerManager with args: "
            f"{_get_dataclass_print_str(params)}"
        )

        sglang_args_dict = params.to_sglang_params()
        model_path = sglang_args_dict.pop("model")

        # Create the SGLang server configuration
        server_config = sgl.RuntimeConfig(
            model=model_path,
            tensor_parallel_size=params.tensor_parallel_size,
            max_model_len=params.max_model_len or 4096,
            dtype=params.dtype if params.dtype != "auto" else None,
            quantization=params.quantization,
            gpu_memory_utilization=params.gpu_memory_utilization,
            **sglang_args_dict.get("extras", {}),
        )

        # Create the async server manager
        engine = AsyncServerManager(server_config)

        # Load the tokenizer
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=params.trust_remote_code,
            revision=params.tokenizer_revision,
        )

        return engine, tokenizer

    def support_async(self) -> bool:
        return True

    def get_async_generate_stream_function(
        self, model, deploy_model_params: LLMDeployModelParameters
    ):
        from dbgpt.model.llm.llm_out.sglang_llm import generate_stream

        return generate_stream

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> ConversationAdapter:
        return self.conv_factory.get_by_model(model_name, model_path)

    def __str__(self) -> str:
        return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)


register_model_adapter(SGLangModelAdapterWrapper, supported_models=COMMON_HF_MODELS)
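
A minimal sketch (not part of this commit) of how the parameter conversion above behaves; the field names come from the dataclass defined in this file, while the concrete values are only illustrative:

```python
from dbgpt.model.adapter.sglang_adapter import SGlangDeployModelParameters

# Illustrative values; a real deployment reads these from the TOML config.
params = SGlangDeployModelParameters(
    name="DeepSeek-R1-Distill-Qwen-1.5B",
    path="models/DeepSeek-R1-Distill-Qwen-1.5B",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.9,
)

# provider/path/name/extras are filtered out and the model path (or name)
# is re-injected under the "model" key expected by the SGLang engine.
engine_kwargs = params.to_sglang_params()
print(engine_kwargs["model"])
```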

@@ -17,6 +17,7 @@ class ModelType:
    LLAMA_CPP_SERVER = "llama.cpp.server"
    PROXY = "proxy"
    VLLM = "vllm"
    SGLANG = "sglang"
    # TODO, support more model type

packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, List

from dbgpt.core import (
    ChatCompletionResponse,
    ChatCompletionResponseChoice,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    DeltaMessage,
    ModelMessage,
    ModelMessageRole,
)
from dbgpt.model.parameter import ModelParameters

logger = logging.getLogger(__name__)


async def generate_stream(
    model: Any,
    tokenizer: Any,
    params: Dict[str, Any],
    model_messages: List[ModelMessage],
    model_parameters: ModelParameters,
) -> AsyncIterator[ChatCompletionStreamResponse]:
    """Generate stream response using SGLang."""
    try:
        import sglang as sgl
    except ImportError:
        raise ImportError("Please install sglang first: pip install sglang")

    # Convert message format
    messages = []
    for msg in model_messages:
        role = msg.role
        if role == ModelMessageRole.HUMAN:
            role = "user"
        elif role == ModelMessageRole.SYSTEM:
            role = "system"
        elif role == ModelMessageRole.AI:
            role = "assistant"
        else:
            role = "user"

        messages.append({"role": role, "content": msg.content})

    # Set model parameters
    temperature = model_parameters.temperature
    top_p = model_parameters.top_p
    max_tokens = model_parameters.max_new_tokens

    # Create SGLang request
    async def stream_generator():
        # Generate with the SGLang async API
        state = sgl.RuntimeState()

        @sgl.function
        def chat(state, messages):
            sgl.gen(
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
            )

        # Start the generation task
        task = model.submit_task(chat, state, messages)

        # Fetch the results
        generated_text = ""
        async for output in task.stream_output():
            if hasattr(output, "text"):
                new_text = output.text
                delta = new_text[len(generated_text) :]
                generated_text = new_text

                # Create the stream response
                choice = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role="assistant", content=delta),
                    finish_reason=None,
                )
                chunk = ChatCompletionStreamResponse(
                    id=params.get("id", "chatcmpl-default"),
                    model=params.get("model", "sglang-model"),
                    choices=[choice],
                    created=int(asyncio.get_event_loop().time()),
                )
                yield chunk

        # Send the completion signal
        choice = ChatCompletionResponseStreamChoice(
            index=0,
            delta=DeltaMessage(role="assistant", content=""),
            finish_reason="stop",
        )
        chunk = ChatCompletionStreamResponse(
            id=params.get("id", "chatcmpl-default"),
            model=params.get("model", "sglang-model"),
            choices=[choice],
            created=int(asyncio.get_event_loop().time()),
        )
        yield chunk

    async for chunk in stream_generator():
        yield chunk


async def generate(
    model: Any,
    tokenizer: Any,
    params: Dict[str, Any],
    model_messages: List[ModelMessage],
    model_parameters: ModelParameters,
) -> ChatCompletionResponse:
    """Generate completion using SGLang."""
    try:
        import sglang as sgl
    except ImportError:
        raise ImportError("Please install sglang first: pip install sglang")

    # Convert messages to the SGLang format
    messages = []
    for msg in model_messages:
        role = msg.role
        if role == ModelMessageRole.HUMAN:
            role = "user"
        elif role == ModelMessageRole.SYSTEM:
            role = "system"
        elif role == ModelMessageRole.AI:
            role = "assistant"
        else:
            role = "user"

        messages.append({"role": role, "content": msg.content})

    temperature = model_parameters.temperature
    top_p = model_parameters.top_p
    max_tokens = model_parameters.max_new_tokens

    state = sgl.RuntimeState()

    @sgl.function
    def chat(state, messages):
        sgl.gen(
            messages=messages,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )

    task = await model.submit_task(chat, state, messages)
    result = await task.wait()

    choice = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role="assistant", content=result.text),
        finish_reason="stop",
    )

    response = ChatCompletionResponse(
        id=params.get("id", "chatcmpl-default"),
        model=params.get("model", "sglang-model"),
        choices=[choice],
        created=int(asyncio.get_event_loop().time()),
    )

    return response