diff --git a/configs/dbgpt-local-sglang.toml b/configs/dbgpt-local-sglang.toml
new file mode 100644
index 000000000..2bb5710e1
--- /dev/null
+++ b/configs/dbgpt-local-sglang.toml
@@ -0,0 +1,39 @@
+[system]
+# Load language from environment variable (it is set by the hook)
+language = "${env:DBGPT_LANG:-zh}"
+api_keys = []
+encrypt_key = "your_secret_key"
+
+# Server Configurations
+[service.web]
+host = "0.0.0.0"
+port = 5670
+
+[service.web.database]
+type = "sqlite"
+path = "pilot/meta_data/dbgpt.db"
+
+[rag.storage]
+[rag.storage.vector]
+type = "chroma"
+persist_path = "pilot/data"
+
+# Model Configurations
+[models]
+[[models.llms]]
+name = "DeepSeek-R1-Distill-Qwen-1.5B"
+provider = "sglang"
+# If not provided, the model will be downloaded from the Hugging Face model hub
+# uncomment the following line to specify the model path in the local file system
+# path = "the-model-path-in-the-local-file-system"
+path = "models/DeepSeek-R1-Distill-Qwen-1.5B"
+# dtype = "float32"
+
+[[models.embeddings]]
+name = "BAAI/bge-large-zh-v1.5"
+provider = "hf"
+# If not provided, the model will be downloaded from the Hugging Face model hub
+# uncomment the following line to specify the model path in the local file system
+# path = "the-model-path-in-the-local-file-system"
+path = "/data/models/bge-large-zh-v1.5"
+
diff --git a/docs/docs/quickstart.md b/docs/docs/quickstart.md
index d635cb40f..19cca7843 100644
--- a/docs/docs/quickstart.md
+++ b/docs/docs/quickstart.md
@@ -97,6 +97,7 @@ This tutorial assumes that you can establish network communication with the depe
{label: 'DeepSeek (proxy)', value: 'deepseek'},
{label: 'GLM4 (local)', value: 'glm-4'},
{label: 'VLLM (local)', value: 'vllm'},
+ {label: 'SGLang (local)', value: 'sglang'},
{label: 'LLAMA_CPP (local)', value: 'llama_cpp'},
{label: 'Ollama (proxy)', value: 'ollama'},
]}>
@@ -291,6 +292,54 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
```
+
+
+
+```bash
+# Use uv to install dependencies needed for sglang
+# Install core dependencies and select desired extensions
+uv sync --all-packages \
+--extra "base" \
+--extra "hf" \
+--extra "cuda121" \
+--extra "sglang" \
+--extra "rag" \
+--extra "storage_chromadb" \
+--extra "quant_bnb" \
+--extra "dbgpts"
+```
+
+### Run Webserver
+
+To run DB-GPT with a local model, you can modify the `configs/dbgpt-local-sglang.toml` configuration file to specify the model path and other parameters.
+
+```toml
+# Model Configurations
+[models]
+[[models.llms]]
+name = "THUDM/glm-4-9b-chat-hf"
+provider = "sglang"
+# If not provided, the model will be downloaded from the Hugging Face model hub
+# uncomment the following line to specify the model path in the local file system
+# path = "the-model-path-in-the-local-file-system"
+
+[[models.embeddings]]
+name = "BAAI/bge-large-zh-v1.5"
+provider = "hf"
+# If not provided, the model will be downloaded from the Hugging Face model hub
+# uncomment the following line to specify the model path in the local file system
+# path = "the-model-path-in-the-local-file-system"
+```
+In the above configuration file, `[[models.llms]]` specifies the LLM model, and `[[models.embeddings]]` specifies the embedding model. If you do not provide the `path` parameter, the model will be downloaded from the Hugging Face model hub according to the `name` parameter.
+
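+The SGLang provider also accepts engine-level options corresponding to its deploy parameter fields (for example `tensor_parallel_size`, `gpu_memory_utilization`, `max_model_len`, and `dtype`). As a minimal sketch, the LLM entry above could be extended as follows; the values are placeholders, not tuned recommendations:
+
+```toml
+[[models.llms]]
+name = "THUDM/glm-4-9b-chat-hf"
+provider = "sglang"
+# Illustrative engine options; adjust or omit them for your hardware
+tensor_parallel_size = 1
+gpu_memory_utilization = 0.9
+max_model_len = 4096
+# dtype = "bfloat16"
+```
+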
+Then run the following command to start the webserver:
+
+```bash
+uv run dbgpt start webserver --config configs/dbgpt-local-sglang.toml
+```
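+
+Once the webserver starts, the service listens on the host and port configured under `[service.web]` in `configs/dbgpt-local-sglang.toml` (`0.0.0.0:5670` by default).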
+
+
+
If you have an Nvidia GPU, you can enable CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.
diff --git a/install_help.py b/install_help.py
index 958ac524a..9aec3731a 100755
--- a/install_help.py
+++ b/install_help.py
@@ -95,6 +95,9 @@ class I18N:
"vllm_preset": "VLLM Local Mode",
"vllm_desc": "Using VLLM framework to load local model, requires GPU environment",
"vllm_note": "Requires local model path configuration",
+ "sglang_preset": "SGLang local Mode",
+ "sglang_desc": "Use SGlang framework to load local model, requires GPU environment",
+ "sglang_note": "Requires local model path configuration",
"llama_cpp_preset": "LLAMA_CPP Local Mode",
"llama_cpp_desc": "Using LLAMA.cpp framework to load local model, can run on CPU but GPU recommended",
"llama_cpp_note": 'Requires local model path configuration, for CUDA support set CMAKE_ARGS="-DGGML_CUDA=ON"',
@@ -175,6 +178,9 @@ class I18N:
"vllm_preset": "VLLM 本地模式",
"vllm_desc": "使用VLLM框架加载本地模型,需要GPU环境",
"vllm_note": "需要配置本地模型路径",
+ "sglang_preset": "SGLang 本地模式",
+ "sglang_desc": "使用SGLang框架加载本地模型,需要GPU环境",
+ "sglang_note": "需要配置本地模型路径",
"llama_cpp_preset": "LLAMA_CPP 本地模式",
"llama_cpp_desc": "使用LLAMA.cpp框架加载本地模型,CPU也可运行但推荐GPU",
"llama_cpp_note": '需要配置本地模型路径,启用CUDA需设置CMAKE_ARGS="-DGGML_CUDA=ON"',
@@ -361,6 +367,21 @@ def get_deployment_presets():
"description": i18n.get("vllm_desc"),
"note": i18n.get("vllm_note"),
},
+ i18n.get("sglang_preset"): {
+ "extras": [
+ "base",
+ "hf",
+ "cuda121",
+ "sglang",
+ "rag",
+ "storage_chromadb",
+ "quant_bnb",
+ "dbgpts",
+ ],
+ "config": "configs/dbgpt-local-sglang.toml",
+ "description": i18n.get("sgalang_desc"),
+ "note": i18n.get("sglang_note"),
+ },
i18n.get("llama_cpp_preset"): {
"extras": [
"base",
diff --git a/packages/dbgpt-core/src/dbgpt/core/interface/parameter.py b/packages/dbgpt-core/src/dbgpt/core/interface/parameter.py
index da51b0fbf..54f5d4290 100644
--- a/packages/dbgpt-core/src/dbgpt/core/interface/parameter.py
+++ b/packages/dbgpt-core/src/dbgpt/core/interface/parameter.py
@@ -40,6 +40,7 @@ class BaseDeployModelParameters(BaseParameters):
"llama_cpp_server",
"proxy/*",
"vllm",
+ "sglang",
],
},
)
diff --git a/packages/dbgpt-core/src/dbgpt/model/__init__.py b/packages/dbgpt-core/src/dbgpt/model/__init__.py
index 1e22c3e45..ffbee2091 100644
--- a/packages/dbgpt-core/src/dbgpt/model/__init__.py
+++ b/packages/dbgpt-core/src/dbgpt/model/__init__.py
@@ -40,6 +40,7 @@ def scan_model_providers():
"hf_adapter",
"llama_cpp_adapter",
"llama_cpp_py_adapter",
+ "sglang_adapter",
],
)
config_llms = ScannerConfig(
diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py b/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py
new file mode 100644
index 000000000..27b104668
--- /dev/null
+++ b/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py
@@ -0,0 +1,373 @@
+import logging
+from dataclasses import dataclass, field, fields, is_dataclass
+from typing import Any, Dict, List, Optional, Type
+
+from dbgpt.core import ModelMessage
+from dbgpt.core.interface.parameter import LLMDeployModelParameters
+from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
+from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
+from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
+from dbgpt.model.base import ModelType
+from dbgpt.util.i18n_utils import _
+from dbgpt.util.parameter_utils import (
+ _get_dataclass_print_str,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class SGlangDeployModelParameters(LLMDeployModelParameters):
+ """SGlang deploy model parameters"""
+
+ provider: str = "sglang"
+
+ path: Optional[str] = field(
+ default=None,
+ metadata={
+ "order": -800,
+ "help": _("The path of the model, if you want to deploy a local model."),
+ },
+ )
+
+ device: Optional[str] = field(
+ default="auto",
+ metadata={
+ "order": -700,
+ "help": _(
+ "The device to run the model, 'auto' means using the default device."
+ ),
+ },
+ )
+
+ concurrency: Optional[int] = field(
+ default=100, metadata={"help": _("Model concurrency limit")}
+ )
+
+ @property
+ def real_model_path(self) -> Optional[str]:
+ """Get the real model path
+
+ If deploy model is not local, return None
+ """
+ return self._resolve_root_path(self.path)
+
+ @property
+ def real_device(self) -> Optional[str]:
+ """Get the real device"""
+
+ return self.device or super().real_device
+
+ def to_sglang_params(
+ self, sglang_config_cls: Optional[Type] = None
+ ) -> Dict[str, Any]:
+ """Convert to sglang deploy model parameters"""
+
+ data = self.to_dict()
+ model = data.get("path", None)
+ if not model:
+ model = data.get("name", None)
+ if not model:
+ raise ValueError(
+ _("Model is required, please pecify the model path or name.")
+ )
+
+ copy_data = data.copy()
+ real_params = {}
+ extra_params = copy_data.get("extras", {})
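+ # If the SGLang engine config dataclass is available, keep only its declared
+ # fields; otherwise drop adapter-only keys and forward everything else.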
+ if sglang_config_cls and is_dataclass(sglang_config_cls):
+ for fd in fields(sglang_config_cls):
+ if fd.name in copy_data:
+ real_params[fd.name] = copy_data[fd.name]
+
+ else:
+ for k, v in copy_data.items():
+ if k in [
+ "provider",
+ "path",
+ "name",
+ "extras",
+ "verbose",
+ "backend",
+ "prompt_template",
+ "context_length",
+ ]:
+ continue
+ real_params[k] = v
+
+ real_params["model"] = model
+ if extra_params and isinstance(extra_params, dict):
+ real_params.update(extra_params)
+ return real_params
+
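+ # The fields below are engine options that to_sglang_params() forwards to the SGLang engine.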
+ trust_remote_code: Optional[bool] = field(
+ default=True, metadata={"help": _("Trust remote code or not.")}
+ )
+
+ download_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": _(
+ "Directory to download and load the weights, "
+ "default to the default cache dir of "
+ "buggingface."
+ )
+ },
+ )
+
+ load_format: Optional[str] = field(
+ default="auto",
+ metadata={
+ "help": _(
+ "The format of the model weights to load.\n\n"
+ '* "auto" will try to load the weights in the safetensors format '
+ "and fall back to the pytorch bin format if safetensors format "
+ "is not available.\n"
+ '* "pt" will load the weights in the pytorch bin format.\n'
+ '* "safetensors" will load the weights in the safetensors format.\n'
+ '* "npcache" will load the weights in pytorch format and store '
+ "a numpy cache to speed up the loading.\n"
+ '* "dummy" will initialize the weights with random values, '
+ "which is mainly for profiling.\n"
+ '* "tensorizer" will load the weights using tensorizer from '
+ "CoreWeave. See the Tensorize vLLM Model script in the Examples "
+ "section for more information.\n"
+ '* "runai_streamer" will load the Safetensors weights using Run:ai'
+ "Model Streamer \n"
+ '* "bitsandbytes" will load the weights using bitsandbytes '
+ "quantization.\n"
+ ),
+ "valid_values": [
+ "auto",
+ "pt",
+ "safetensors",
+ "npcache",
+ "dummy",
+ "tensorizer",
+ "runai_streamer",
+ "bitsandbytes",
+ "sharded_state",
+ "gguf",
+ "mistral",
+ ],
+ },
+ )
+ config_format: Optional[str] = field(
+ default="auto",
+ metadata={
+ "help": _(
+ "The format of the model config to load.\n\n"
+ '* "auto" will try to load the config in hf format '
+ "if available else it will try to load in mistral format "
+ ),
+ "valid_values": [
+ "auto",
+ "hf",
+ "mistral",
+ ],
+ },
+ )
+
+ dtype: Optional[str] = field(
+ default="auto",
+ metadata={
+ "help": _(
+ "Data type for model weights and activations.\n\n"
+ '* "auto" will use FP16 precision for FP32 and FP16 models, and '
+ "BF16 precision for BF16 models.\n"
+ '* "half" for FP16.\n'
+ '* "float16" is the same as "half".\n'
+ '* "bfloat16" for a balance between precision and range.\n'
+ '* "float" is shorthand for FP32 precision.\n'
+ '* "float32" for FP32 precision.'
+ ),
+ "valid_values": ["auto", "half", "float16", "bfloat16", "float", "float32"],
+ },
+ )
+
+ max_model_len: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": _(
+ "Model context length. If unspecified, will be automatically derived "
+ "from the model config."
+ ),
+ },
+ )
+ tensor_parallel_size: int = field(
+ default=1,
+ metadata={
+ "help": _("Number of tensor parallel replicas."),
+ },
+ )
+ max_num_seqs: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": _("Maximum number of sequences per iteration."),
+ },
+ )
+ revision: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": _(
+ "The specific model version to use. It can be a branch "
+ "name, a tag name, or a commit id. If unspecified, will use "
+ "the default version."
+ ),
+ },
+ )
+ tokenizer_revision: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": _(
+ "Revision of the huggingface tokenizer to use. "
+ "It can be a branch name, a tag name, or a commit id. "
+ "If unspecified, will use the default version."
+ ),
+ },
+ )
+ quantization: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": _(
+ "Method used to quantize the weights. If "
+ "None, we first check the `quantization_config` "
+ "attribute in the model config file. If that is "
+ "None, we assume the model weights are not "
+ "quantized and use `dtype` to determine the data "
+ "type of the weights."
+ ),
+ "valid_values": ["awq", "gptq", "int8"],
+ },
+ )
+ gpu_memory_utilization: float = field(
+ default=0.90,
+ metadata={
+ "help": _("The fraction of GPU memory to be used for the model."),
+ },
+ )
+ extras: Optional[Dict] = field(
+ default=None,
+ metadata={
+ "help": _("Extra parameters, it will be passed to the sglang engine.")
+ },
+ )
+
+
+class SGLangModelAdapterWrapper(LLMModelAdapter):
+ """Wrapping sglang engine"""
+
+ def __init__(self, conv_factory: Optional[ConversationAdapterFactory] = None):
+ if not conv_factory:
+ from dbgpt.model.adapter.model_adapter import (
+ DefaultConversationAdapterFactory,
+ )
+
+ conv_factory = DefaultConversationAdapterFactory()
+
+ self.conv_factory = conv_factory
+
+ def new_adapter(self, **kwargs) -> "SGLangModelAdapterWrapper":
+ new_obj = super().new_adapter(**kwargs)
+ new_obj.conv_factory = self.conv_factory
+ return new_obj # type: ignore
+
+ def model_type(self) -> str:
+ return ModelType.SGLANG
+
+ def model_param_class(
+ self, model_type: str = None
+ ) -> Type[SGlangDeployModelParameters]:
+ """Get model parameters class."""
+ return SGlangDeployModelParameters
+
+ def match(
+ self,
+ provider: str,
+ model_name: Optional[str] = None,
+ model_path: Optional[str] = None,
+ ) -> bool:
+ return provider == ModelType.SGLANG
+
+ def get_str_prompt(
+ self,
+ params: Dict,
+ messages: List[ModelMessage],
+ tokenizer: Any,
+ prompt_template: str = None,
+ convert_to_compatible_format: bool = False,
+ ) -> Optional[str]:
+ if not tokenizer:
+ raise ValueError("tokenizer is None")
+
+ if hasattr(tokenizer, "apply_chat_template"):
+ messages = self.transform_model_messages(
+ messages, convert_to_compatible_format
+ )
+ logger.debug(f"The messages after transform: \n{messages}")
+ str_prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ return str_prompt
+ return None
+
+ def load_from_params(self, params: SGlangDeployModelParameters):
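+ """Create the SGLang engine and tokenizer from the deploy parameters."""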
+ try:
+ import sglang as sgl
+ from sglang.srt.managers.server import AsyncServerManager
+ except ImportError:
+ raise ImportError("Please install sglang first: pip install sglang")
+
+ logger.info(
+ "Start SGLang AsyncServerManager with args: "
+ f"{_get_dataclass_print_str(params)}"
+ )
+
+ sglang_args_dict = params.to_sglang_params()
+ model_path = sglang_args_dict.pop("model")
+
+ # Create the SGLang server configuration
+ server_config = sgl.RuntimeConfig(
+ model=model_path,
+ tensor_parallel_size=params.tensor_parallel_size,
+ max_model_len=params.max_model_len or 4096,
+ dtype=params.dtype if params.dtype != "auto" else None,
+ quantization=params.quantization,
+ gpu_memory_utilization=params.gpu_memory_utilization,
+ **sglang_args_dict.get("extras", {}),
+ )
+
+ # Create the async server manager
+ engine = AsyncServerManager(server_config)
+
+ # Load the tokenizer
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ trust_remote_code=params.trust_remote_code,
+ revision=params.tokenizer_revision,
+ )
+
+ return engine, tokenizer
+
+ def support_async(self) -> bool:
+ return True
+
+ def get_async_generate_stream_function(
+ self, model, deploy_model_params: LLMDeployModelParameters
+ ):
+ from dbgpt.model.llm.llm_out.sglang_llm import generate_stream
+
+ return generate_stream
+
+ def get_default_conv_template(
+ self, model_name: str, model_path: str
+ ) -> ConversationAdapter:
+ return self.conv_factory.get_by_model(model_name, model_path)
+
+ def __str__(self) -> str:
+ return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
+
+
+register_model_adapter(SGLangModelAdapterWrapper, supported_models=COMMON_HF_MODELS)
diff --git a/packages/dbgpt-core/src/dbgpt/model/base.py b/packages/dbgpt-core/src/dbgpt/model/base.py
index 38fedfc6b..7ce258e0e 100644
--- a/packages/dbgpt-core/src/dbgpt/model/base.py
+++ b/packages/dbgpt-core/src/dbgpt/model/base.py
@@ -17,6 +17,7 @@ class ModelType:
LLAMA_CPP_SERVER = "llama.cpp.server"
PROXY = "proxy"
VLLM = "vllm"
+ SGLANG = "sglang"
# TODO, support more model type
diff --git a/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py
new file mode 100644
index 000000000..7c91372e8
--- /dev/null
+++ b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py
@@ -0,0 +1,169 @@
+import logging
+import time
+from typing import Any, AsyncIterator, Dict, List
+
+from dbgpt.core import (
+ ChatCompletionResponse,
+ ChatCompletionResponseChoice,
+ ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse,
+ ChatMessage,
+ DeltaMessage,
+ ModelMessage,
+ ModelMessageRole,
+)
+from dbgpt.model.parameter import ModelParameters
+
+logger = logging.getLogger(__name__)
+
+
+async def generate_stream(
+ model: Any,
+ tokenizer: Any,
+ params: Dict[str, Any],
+ model_messages: List[ModelMessage],
+ model_parameters: ModelParameters,
+) -> AsyncIterator[ChatCompletionStreamResponse]:
+ """Generate stream response using SGLang."""
+ try:
+ import sglang as sgl
+ except ImportError:
+ raise ImportError("Please install sglang first: pip install sglang")
+
+ # Convert ModelMessage roles to chat-format message dicts
+ messages = []
+ for msg in model_messages:
+ role = msg.role
+ if role == ModelMessageRole.HUMAN:
+ role = "user"
+ elif role == ModelMessageRole.SYSTEM:
+ role = "system"
+ elif role == ModelMessageRole.AI:
+ role = "assistant"
+ else:
+ role = "user"
+
+ messages.append({"role": role, "content": msg.content})
+
+ # Set generation parameters
+ temperature = model_parameters.temperature
+ top_p = model_parameters.top_p
+ max_tokens = model_parameters.max_new_tokens
+
+ # Create SGLang request
+ async def stream_generator():
+ # Generate using the SGLang async API
+ state = sgl.RuntimeState()
+
+ @sgl.function
+ def chat(state, messages):
+ sgl.gen(
+ messages=messages,
+ temperature=temperature,
+ top_p=top_p,
+ max_tokens=max_tokens,
+ )
+
+ # Submit the generation task
+ task = model.submit_task(chat, state, messages)
+
+ # Fetch the streamed output
+ generated_text = ""
+ async for output in task.stream_output():
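+ # output.text is cumulative, so emit only the newly generated suffix as the delta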
+ if hasattr(output, "text"):
+ new_text = output.text
+ delta = new_text[len(generated_text) :]
+ generated_text = new_text
+
+ # Create stream response chunk
+ choice = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(role="assistant", content=delta),
+ finish_reason=None,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=params.get("id", "chatcmpl-default"),
+ model=params.get("model", "sglang-model"),
+ choices=[choice],
+ created=int(time.time()),
+ )
+ yield chunk
+
+ # Send the final chunk to signal completion
+ choice = ChatCompletionResponseStreamChoice(
+ index=0,
+ delta=DeltaMessage(role="assistant", content=""),
+ finish_reason="stop",
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=params.get("id", "chatcmpl-default"),
+ model=params.get("model", "sglang-model"),
+ choices=[choice],
+ created=int(time.time()),
+ )
+ yield chunk
+
+ async for chunk in stream_generator():
+ yield chunk
+
+
+async def generate(
+ model: Any,
+ tokenizer: Any,
+ params: Dict[str, Any],
+ model_messages: List[ModelMessage],
+ model_parameters: ModelParameters,
+) -> ChatCompletionResponse:
+ """Generate completion using SGLang."""
+ try:
+ import sglang as sgl
+ except ImportError:
+ raise ImportError("Please install sglang first: pip install sglang")
+
+ # Convert ModelMessage roles to chat-format message dicts
+ messages = []
+ for msg in model_messages:
+ role = msg.role
+ if role == ModelMessageRole.HUMAN:
+ role = "user"
+ elif role == ModelMessageRole.SYSTEM:
+ role = "system"
+ elif role == ModelMessageRole.AI:
+ role = "assistant"
+ else:
+ role = "user"
+
+ messages.append({"role": role, "content": msg.content})
+
+ temperature = model_parameters.temperature
+ top_p = model_parameters.top_p
+ max_tokens = model_parameters.max_new_tokens
+
+ state = sgl.RuntimeState()
+
+ @sgl.function
+ def chat(state, messages):
+ sgl.gen(
+ messages=messages,
+ temperature=temperature,
+ top_p=top_p,
+ max_tokens=max_tokens,
+ )
+
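+ # Submit the task and wait for the full (non-streaming) result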
+ task = await model.submit_task(chat, state, messages)
+ result = await task.wait()
+
+ choice = ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role="assistant", content=result.text),
+ finish_reason="stop",
+ )
+
+ response = ChatCompletionResponse(
+ id=params.get("id", "chatcmpl-default"),
+ model=params.get("model", "sglang-model"),
+ choices=[choice],
+ created=int(time.time()),
+ )
+
+ return response