mirror of
https://github.com/csunny/DB-GPT.git
synced 2025-08-24 19:13:33 +00:00
feat: add sglang support
This commit is contained in:
parent
c68332be4c
commit
7b43a039ac
39
configs/dbgpt-local-sglang.toml
Normal file
39
configs/dbgpt-local-sglang.toml
Normal file
@ -0,0 +1,39 @@
|
||||
[system]
|
||||
# Load language from environment variable(It is set by the hook)
|
||||
language = "${env:DBGPT_LANG:-zh}"
|
||||
api_keys = []
|
||||
encrypt_key = "your_secret_key"
|
||||
|
||||
# Server Configurations
|
||||
[service.web]
|
||||
host = "0.0.0.0"
|
||||
port = 5670
|
||||
|
||||
[service.web.database]
|
||||
type = "sqlite"
|
||||
path = "pilot/meta_data/dbgpt.db"
|
||||
|
||||
[rag.storage]
|
||||
[rag.storage.vector]
|
||||
type = "chroma"
|
||||
persist_path = "pilot/data"
|
||||
|
||||
# Model Configurations
|
||||
[models]
|
||||
[[models.llms]]
|
||||
name = "DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
provider = "sglang"
|
||||
# If not provided, the model will be downloaded from the Hugging Face model hub
|
||||
# uncomment the following line to specify the model path in the local file system
|
||||
# path = "the-model-path-in-the-local-file-system"
|
||||
path = "models/DeepSeek-R1-Distill-Qwen-1.5B"
|
||||
# dtype = "float32"
|
||||
|
||||
[[models.embeddings]]
|
||||
name = "BAAI/bge-large-zh-v1.5"
|
||||
provider = "hf"
|
||||
# If not provided, the model will be downloaded from the Hugging Face model hub
|
||||
# uncomment the following line to specify the model path in the local file system
|
||||
# path = "the-model-path-in-the-local-file-system"
|
||||
path = "/data/models/bge-large-zh-v1.5"
|
||||
|
@ -97,6 +97,7 @@ This tutorial assumes that you can establish network communication with the depe
|
||||
{label: 'DeepSeek (proxy)', value: 'deepseek'},
|
||||
{label: 'GLM4 (local)', value: 'glm-4'},
|
||||
{label: 'VLLM (local)', value: 'vllm'},
|
||||
{label: 'SGLang (local)', value: 'sglang'},
|
||||
{label: 'LLAMA_CPP (local)', value: 'llama_cpp'},
|
||||
{label: 'Ollama (proxy)', value: 'ollama'},
|
||||
]}>
|
||||
@ -291,6 +292,54 @@ uv run dbgpt start webserver --config configs/dbgpt-local-vllm.toml
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="sglang" label="SGLang(local)">
|
||||
|
||||
```bash
|
||||
# Use uv to install dependencies needed for vllm
|
||||
# Install core dependencies and select desired extensions
|
||||
uv sync --all-packages \
|
||||
--extra "base" \
|
||||
--extra "hf" \
|
||||
--extra "cuda121" \
|
||||
--extra "sglang" \
|
||||
--extra "rag" \
|
||||
--extra "storage_chromadb" \
|
||||
--extra "quant_bnb" \
|
||||
--extra "dbgpts"
|
||||
```
|
||||
|
||||
### Run Webserver
|
||||
|
||||
To run DB-GPT with the local model. You can modify the `configs/dbgpt-local-sglang.toml` configuration file to specify the model path and other parameters.
|
||||
|
||||
```toml
|
||||
# Model Configurations
|
||||
[models]
|
||||
[[models.llms]]
|
||||
name = "THUDM/glm-4-9b-chat-hf"
|
||||
provider = "sglang"
|
||||
# If not provided, the model will be downloaded from the Hugging Face model hub
|
||||
# uncomment the following line to specify the model path in the local file system
|
||||
# path = "the-model-path-in-the-local-file-system"
|
||||
|
||||
[[models.embeddings]]
|
||||
name = "BAAI/bge-large-zh-v1.5"
|
||||
provider = "hf"
|
||||
# If not provided, the model will be downloaded from the Hugging Face model hub
|
||||
# uncomment the following line to specify the model path in the local file system
|
||||
# path = "the-model-path-in-the-local-file-system"
|
||||
```
|
||||
In the above configuration file, `[[models.llms]]` specifies the LLM model, and `[[models.embeddings]]` specifies the embedding model. If you not provide the `path` parameter, the model will be downloaded from the Hugging Face model hub according to the `name` parameter.
|
||||
|
||||
Then run the following command to start the webserver:
|
||||
|
||||
```bash
|
||||
uv run dbgpt start webserver --config configs/dbgpt-local-sglang.toml
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="llama_cpp" label="LLAMA_CPP(local)">
|
||||
|
||||
If you has a Nvidia GPU, you can enable the CUDA support by setting the environment variable `CMAKE_ARGS="-DGGML_CUDA=ON"`.
|
||||
|
@ -95,6 +95,9 @@ class I18N:
|
||||
"vllm_preset": "VLLM Local Mode",
|
||||
"vllm_desc": "Using VLLM framework to load local model, requires GPU environment",
|
||||
"vllm_note": "Requires local model path configuration",
|
||||
"sglang_preset": "SGLang local Mode",
|
||||
"sglang_desc": "Use SGlang framework to load local model, requires GPU environment",
|
||||
"sglang_note": "Requires local model path configuration",
|
||||
"llama_cpp_preset": "LLAMA_CPP Local Mode",
|
||||
"llama_cpp_desc": "Using LLAMA.cpp framework to load local model, can run on CPU but GPU recommended",
|
||||
"llama_cpp_note": 'Requires local model path configuration, for CUDA support set CMAKE_ARGS="-DGGML_CUDA=ON"',
|
||||
@ -175,6 +178,9 @@ class I18N:
|
||||
"vllm_preset": "VLLM 本地模式",
|
||||
"vllm_desc": "使用VLLM框架加载本地模型,需要GPU环境",
|
||||
"vllm_note": "需要配置本地模型路径",
|
||||
"sglang_preset": "SGLang 本地模式",
|
||||
"sglang_desc": "使用SGLang框架加载本地模型,需要GPU环境",
|
||||
"sglang_note": "需要配置本地模型路径",
|
||||
"llama_cpp_preset": "LLAMA_CPP 本地模式",
|
||||
"llama_cpp_desc": "使用LLAMA.cpp框架加载本地模型,CPU也可运行但推荐GPU",
|
||||
"llama_cpp_note": '需要配置本地模型路径,启用CUDA需设置CMAKE_ARGS="-DGGML_CUDA=ON"',
|
||||
@ -361,6 +367,21 @@ def get_deployment_presets():
|
||||
"description": i18n.get("vllm_desc"),
|
||||
"note": i18n.get("vllm_note"),
|
||||
},
|
||||
i18n.get("sglang_preset"): {
|
||||
"extras": [
|
||||
"base",
|
||||
"hf",
|
||||
"cuda121",
|
||||
"sglang",
|
||||
"rag",
|
||||
"storage_chromadb",
|
||||
"quant_bnb",
|
||||
"dbgpts",
|
||||
],
|
||||
"config": "configs/dbgpt-local-sglang.toml",
|
||||
"description": i18n.get("sgalang_desc"),
|
||||
"note": i18n.get("sglang_note"),
|
||||
},
|
||||
i18n.get("llama_cpp_preset"): {
|
||||
"extras": [
|
||||
"base",
|
||||
|
@ -40,6 +40,7 @@ class BaseDeployModelParameters(BaseParameters):
|
||||
"llama_cpp_server",
|
||||
"proxy/*",
|
||||
"vllm",
|
||||
"sglang",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
@ -40,6 +40,7 @@ def scan_model_providers():
|
||||
"hf_adapter",
|
||||
"llama_cpp_adapter",
|
||||
"llama_cpp_py_adapter",
|
||||
"sglang_adapter",
|
||||
],
|
||||
)
|
||||
config_llms = ScannerConfig(
|
||||
|
373
packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py
Normal file
373
packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py
Normal file
@ -0,0 +1,373 @@
|
||||
import logging
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from dbgpt.core import ModelMessage
|
||||
from dbgpt.core.interface.parameter import LLMDeployModelParameters
|
||||
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
|
||||
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
|
||||
from dbgpt.model.adapter.template import ConversationAdapter, ConversationAdapterFactory
|
||||
from dbgpt.model.base import ModelType
|
||||
from dbgpt.util.i18n_utils import _
|
||||
from dbgpt.util.parameter_utils import (
|
||||
_get_dataclass_print_str,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SGlangDeployModelParameters(LLMDeployModelParameters):
|
||||
"""SGlang deploy model parameters"""
|
||||
|
||||
provider: str = "sglang"
|
||||
|
||||
path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"order": -800,
|
||||
"help": _("The path of the model, if you want to deploy a local model."),
|
||||
},
|
||||
)
|
||||
|
||||
device: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"order": -700,
|
||||
"help": _(
|
||||
"The device to run the model, 'auto' means using the default device."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
concurrency: Optional[int] = field(
|
||||
default=100, metadata={"help": _("Model concurrency limit")}
|
||||
)
|
||||
|
||||
@property
|
||||
def real_model_path(self) -> Optional[str]:
|
||||
"""Get the real model path
|
||||
|
||||
If deploy model is not local, return None
|
||||
"""
|
||||
return self._resolve_root_path(self.path)
|
||||
|
||||
@property
|
||||
def real_device(self) -> Optional[str]:
|
||||
"""Get the real device"""
|
||||
|
||||
return self.device or super().real_device
|
||||
|
||||
def to_sglang_params(
|
||||
self, sglang_config_cls: Optional[Type] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Convert to sglang deploy model parameters"""
|
||||
|
||||
data = self.to_dict()
|
||||
model = data.get("path", None)
|
||||
if not model:
|
||||
model = data.get("name", None)
|
||||
if not model:
|
||||
raise ValueError(
|
||||
_("Model is required, please pecify the model path or name.")
|
||||
)
|
||||
|
||||
copy_data = data.copy()
|
||||
real_params = {}
|
||||
extra_params = copy_data.get("extras", {})
|
||||
if sglang_config_cls and is_dataclass(sglang_config_cls):
|
||||
for fd in fields(sglang_config_cls):
|
||||
if fd.name in copy_data:
|
||||
real_params[fd.name] = copy_data[fd.name]
|
||||
|
||||
else:
|
||||
for k, v in copy_data.items():
|
||||
if k in [
|
||||
"provider",
|
||||
"path",
|
||||
"name",
|
||||
"extras",
|
||||
"verbose",
|
||||
"backend",
|
||||
"prompt_template",
|
||||
"context_length",
|
||||
]:
|
||||
continue
|
||||
real_params[k] = v
|
||||
|
||||
real_params["model"] = model
|
||||
if extra_params and isinstance(extra_params, dict):
|
||||
real_params.update(extra_params)
|
||||
return real_params
|
||||
|
||||
trust_remote_code: Optional[bool] = field(
|
||||
default=True, metadata={"help": _("Trust remote code or not.")}
|
||||
)
|
||||
|
||||
download_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
"Directory to download and load the weights, "
|
||||
"default to the default cache dir of "
|
||||
"buggingface."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
load_format: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The format of the model weights to load.\n\n"
|
||||
'* "auto" will try to load the weights in the safetensors format '
|
||||
"and fall back to the pytorch bin format if safetensors format "
|
||||
"is not available.\n"
|
||||
'* "pt" will load the weights in the pytorch bin format.\n'
|
||||
'* "safetensors" will load the weights in the safetensors format.\n'
|
||||
'* "npcache" will load the weights in pytorch format and store '
|
||||
"a numpy cache to speed up the loading.\n"
|
||||
'* "dummy" will initialize the weights with random values, '
|
||||
"which is mainly for profiling.\n"
|
||||
'* "tensorizer" will load the weights using tensorizer from '
|
||||
"CoreWeave. See the Tensorize vLLM Model script in the Examples "
|
||||
"section for more information.\n"
|
||||
'* "runai_streamer" will load the Safetensors weights using Run:ai'
|
||||
"Model Streamer \n"
|
||||
'* "bitsandbytes" will load the weights using bitsandbytes '
|
||||
"quantization.\n"
|
||||
),
|
||||
"valid_values": [
|
||||
"auto",
|
||||
"pt",
|
||||
"safetensors",
|
||||
"npcache",
|
||||
"dummy",
|
||||
"tensorizer",
|
||||
"runai_streamer",
|
||||
"bitsandbytes",
|
||||
"sharded_state",
|
||||
"gguf",
|
||||
"mistral",
|
||||
],
|
||||
},
|
||||
)
|
||||
config_format: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": _(
|
||||
"The format of the model config to load.\n\n"
|
||||
'* "auto" will try to load the config in hf format '
|
||||
"if available else it will try to load in mistral format "
|
||||
),
|
||||
"valid_values": [
|
||||
"auto",
|
||||
"hf",
|
||||
"mistral",
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
dtype: Optional[str] = field(
|
||||
default="auto",
|
||||
metadata={
|
||||
"help": _(
|
||||
"Data type for model weights and activations.\n\n"
|
||||
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
|
||||
"BF16 precision for BF16 models.\n"
|
||||
'* "half" for FP16.\n'
|
||||
'* "float16" is the same as "half".\n'
|
||||
'* "bfloat16" for a balance between precision and range.\n'
|
||||
'* "float" is shorthand for FP32 precision.\n'
|
||||
'* "float32" for FP32 precision.'
|
||||
),
|
||||
"valid_values": ["auto", "half", "float16", "bfloat16", "float", "float32"],
|
||||
},
|
||||
)
|
||||
|
||||
max_model_len: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
"Model context length. If unspecified, will be automatically derived "
|
||||
"from the model config."
|
||||
),
|
||||
},
|
||||
)
|
||||
tensor_parallel_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": _("Number of tensor parallel replicas."),
|
||||
},
|
||||
)
|
||||
max_num_seqs: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("Maximum number of sequences per iteration."),
|
||||
},
|
||||
)
|
||||
revision: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
"The specific model version to use. It can be a branch "
|
||||
"name, a tag name, or a commit id. If unspecified, will use "
|
||||
"the default version."
|
||||
),
|
||||
},
|
||||
)
|
||||
tokenizer_revision: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
"Revision of the huggingface tokenizer to use. "
|
||||
"It can be a branch name, a tag name, or a commit id. "
|
||||
"If unspecified, will use the default version."
|
||||
),
|
||||
},
|
||||
)
|
||||
quantization: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _(
|
||||
"Method used to quantize the weights. If "
|
||||
"None, we first check the `quantization_config` "
|
||||
"attribute in the model config file. If that is "
|
||||
"None, we assume the model weights are not "
|
||||
"quantized and use `dtype` to determine the data "
|
||||
"type of the weights."
|
||||
),
|
||||
"valid_values": ["awq", "gptq", "int8"],
|
||||
},
|
||||
)
|
||||
gpu_memory_utilization: float = field(
|
||||
default=0.90,
|
||||
metadata={
|
||||
"help": _("The fraction of GPU memory to be used for the model."),
|
||||
},
|
||||
)
|
||||
extras: Optional[Dict] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": _("Extra parameters, it will be passed to the sglang engine.")
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SGLangModelAdapterWrapper(LLMModelAdapter):
|
||||
"""Wrapping sglang engine"""
|
||||
|
||||
def __init__(self, conv_factory: Optional[ConversationAdapterFactory] = None):
|
||||
if not conv_factory:
|
||||
from dbgpt.model.adapter.model_adapter import (
|
||||
DefaultConversationAdapterFactory,
|
||||
)
|
||||
|
||||
conv_factory = DefaultConversationAdapterFactory()
|
||||
|
||||
self.conv_factory = conv_factory
|
||||
|
||||
def new_adapter(self, **kwargs) -> "SGLangModelAdapterWrapper":
|
||||
new_obj = super().new_adapter(**kwargs)
|
||||
new_obj.conv_factory = self.conv_factory
|
||||
return new_obj # type: ignore
|
||||
|
||||
def model_type(self) -> str:
|
||||
return ModelType.SGLANG
|
||||
|
||||
def model_param_class(
|
||||
self, model_type: str = None
|
||||
) -> Type[SGlangDeployModelParameters]:
|
||||
"""Get model parameters class."""
|
||||
return SGlangDeployModelParameters
|
||||
|
||||
def match(
|
||||
self,
|
||||
provider: str,
|
||||
model_name: Optional[str] = None,
|
||||
model_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
return provider == ModelType.SGLANG
|
||||
|
||||
def get_str_prompt(
|
||||
self,
|
||||
params: Dict,
|
||||
messages: List[ModelMessage],
|
||||
tokenizer: Any,
|
||||
prompt_template: str = None,
|
||||
convert_to_compatible_format: bool = False,
|
||||
) -> Optional[str]:
|
||||
if not tokenizer:
|
||||
raise ValueError("tokenizer is None")
|
||||
|
||||
if hasattr(tokenizer, "apply_chat_template"):
|
||||
messages = self.transform_model_messages(
|
||||
messages, convert_to_compatible_format
|
||||
)
|
||||
logger.debug(f"The messages after transform: \n{messages}")
|
||||
str_prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return str_prompt
|
||||
return None
|
||||
|
||||
def load_from_params(self, params: SGlangDeployModelParameters):
|
||||
try:
|
||||
import sglang as sgl
|
||||
from sglang.srt.managers.server import AsyncServerManager
|
||||
except ImportError:
|
||||
raise ImportError("Please install sglang first: pip install sglang")
|
||||
|
||||
logger.info(
|
||||
f" Start SGLang AsyncServerManager with args: \
|
||||
{_get_dataclass_print_str(params)}"
|
||||
)
|
||||
|
||||
sglang_args_dict = params.to_sglang_params()
|
||||
model_path = sglang_args_dict.pop("model")
|
||||
|
||||
# 创建SGLang服务器配置
|
||||
server_config = sgl.RuntimeConfig(
|
||||
model=model_path,
|
||||
tensor_parallel_size=params.tensor_parallel_size,
|
||||
max_model_len=params.max_model_len or 4096,
|
||||
dtype=params.dtype if params.dtype != "auto" else None,
|
||||
quantization=params.quantization,
|
||||
gpu_memory_utilization=params.gpu_memory_utilization,
|
||||
**sglang_args_dict.get("extras", {}),
|
||||
)
|
||||
|
||||
# 创建异步服务器管理器
|
||||
engine = AsyncServerManager(server_config)
|
||||
|
||||
# 获取tokenizer
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=params.trust_remote_code,
|
||||
revision=params.tokenizer_revision,
|
||||
)
|
||||
|
||||
return engine, tokenizer
|
||||
|
||||
def support_async(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_async_generate_stream_function(
|
||||
self, model, deploy_model_params: LLMDeployModelParameters
|
||||
):
|
||||
from dbgpt.model.llm.llm_out.sglang_llm import generate_stream
|
||||
|
||||
return generate_stream
|
||||
|
||||
def get_default_conv_template(
|
||||
self, model_name: str, model_path: str
|
||||
) -> ConversationAdapter:
|
||||
return self.conv_factory.get_by_model(model_name, model_path)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}.{}".format(self.__class__.__module__, self.__class__.__name__)
|
||||
|
||||
|
||||
register_model_adapter(SGLangModelAdapterWrapper, supported_models=COMMON_HF_MODELS)
|
@ -17,6 +17,7 @@ class ModelType:
|
||||
LLAMA_CPP_SERVER = "llama.cpp.server"
|
||||
PROXY = "proxy"
|
||||
VLLM = "vllm"
|
||||
SGLANG = "sglang"
|
||||
# TODO, support more model type
|
||||
|
||||
|
||||
|
169
packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py
Normal file
169
packages/dbgpt-core/src/dbgpt/model/llm/llm_out/sglang_llm.py
Normal file
@ -0,0 +1,169 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, List
|
||||
|
||||
from dbgpt.core import (
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionResponseChoice,
|
||||
ChatCompletionResponseStreamChoice,
|
||||
ChatCompletionStreamResponse,
|
||||
ChatMessage,
|
||||
DeltaMessage,
|
||||
ModelMessage,
|
||||
ModelMessageRole,
|
||||
)
|
||||
from dbgpt.model.parameter import ModelParameters
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def generate_stream(
|
||||
model: Any,
|
||||
tokenizer: Any,
|
||||
params: Dict[str, Any],
|
||||
model_messages: List[ModelMessage],
|
||||
model_parameters: ModelParameters,
|
||||
) -> AsyncIterator[ChatCompletionStreamResponse]:
|
||||
"""Generate stream response using SGLang."""
|
||||
try:
|
||||
import sglang as sgl
|
||||
except ImportError:
|
||||
raise ImportError("Please install sglang first: pip install sglang")
|
||||
|
||||
# Message format convert
|
||||
messages = []
|
||||
for msg in model_messages:
|
||||
role = msg.role
|
||||
if role == ModelMessageRole.HUMAN:
|
||||
role = "user"
|
||||
elif role == ModelMessageRole.SYSTEM:
|
||||
role = "system"
|
||||
elif role == ModelMessageRole.AI:
|
||||
role = "assistant"
|
||||
else:
|
||||
role = "user"
|
||||
|
||||
messages.append({"role": role, "content": msg.content})
|
||||
|
||||
# Model params set
|
||||
temperature = model_parameters.temperature
|
||||
top_p = model_parameters.top_p
|
||||
max_tokens = model_parameters.max_new_tokens
|
||||
|
||||
# Create SGLang request
|
||||
async def stream_generator():
|
||||
# Use SGLang async API generate
|
||||
state = sgl.RuntimeState()
|
||||
|
||||
@sgl.function
|
||||
def chat(state, messages):
|
||||
sgl.gen(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Start task generate
|
||||
task = model.submit_task(chat, state, messages)
|
||||
|
||||
# Fetch result
|
||||
generated_text = ""
|
||||
async for output in task.stream_output():
|
||||
if hasattr(output, "text"):
|
||||
new_text = output.text
|
||||
delta = new_text[len(generated_text) :]
|
||||
generated_text = new_text
|
||||
|
||||
# Create Stream reponse
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant", content=delta),
|
||||
finish_reason=None,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=params.get("id", "chatcmpl-default"),
|
||||
model=params.get("model", "sglang-model"),
|
||||
choices=[choice],
|
||||
created=int(asyncio.get_event_loop().time()),
|
||||
)
|
||||
yield chunk
|
||||
|
||||
# Send complete signal
|
||||
choice = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role="assistant", content=""),
|
||||
finish_reason="stop",
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=params.get("id", "chatcmpl-default"),
|
||||
model=params.get("model", "sglang-model"),
|
||||
choices=[choice],
|
||||
created=int(asyncio.get_event_loop().time()),
|
||||
)
|
||||
yield chunk
|
||||
|
||||
async for chunk in stream_generator():
|
||||
yield chunk
|
||||
|
||||
|
||||
async def generate(
|
||||
model: Any,
|
||||
tokenizer: Any,
|
||||
params: Dict[str, Any],
|
||||
model_messages: List[ModelMessage],
|
||||
model_parameters: ModelParameters,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Generate completion using SGLang."""
|
||||
try:
|
||||
import sglang as sgl
|
||||
except ImportError:
|
||||
raise ImportError("Please install sglang first: pip install sglang")
|
||||
|
||||
# Convert format to SGlang
|
||||
messages = []
|
||||
for msg in model_messages:
|
||||
role = msg.role
|
||||
if role == ModelMessageRole.HUMAN:
|
||||
role = "user"
|
||||
elif role == ModelMessageRole.SYSTEM:
|
||||
role = "system"
|
||||
elif role == ModelMessageRole.AI:
|
||||
role = "assistant"
|
||||
else:
|
||||
role = "user"
|
||||
|
||||
messages.append({"role": role, "content": msg.content})
|
||||
|
||||
temperature = model_parameters.temperature
|
||||
top_p = model_parameters.top_p
|
||||
max_tokens = model_parameters.max_new_tokens
|
||||
|
||||
state = sgl.RuntimeState()
|
||||
|
||||
@sgl.function
|
||||
def chat(state, messages):
|
||||
sgl.gen(
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
task = await model.submit_task(chat, state, messages)
|
||||
result = await task.wait()
|
||||
|
||||
choice = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role="assistant", content=result.text),
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
response = ChatCompletionResponse(
|
||||
id=params.get("id", "chatcmpl-default"),
|
||||
model=params.get("model", "sglang-model"),
|
||||
choices=[choice],
|
||||
created=int(asyncio.get_event_loop().time()),
|
||||
)
|
||||
|
||||
return response
|
Loading…
Reference in New Issue
Block a user