feat: add sglang config

This commit is contained in:
csunny 2025-03-30 09:42:25 +08:00
parent 7b43a039ac
commit e037525dc9
3 changed files with 87 additions and 171 deletions

View File

@ -72,6 +72,10 @@ vllm = [
# # https://github.com/sasha0552/pascal-pkgs-ci
# "vllm-pascal==0.7.2; sys_platform == 'linux'"
# ]
sglang = [
# Just support GPU version on Linux
"sglang>=0.2.0; sys_platform == 'linux'",
]
quant_bnb = [
"bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
"accelerate"

View File

@ -314,20 +314,20 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
def load_from_params(self, params: SGlangDeployModelParameters):
try:
import sglang as sgl
from sglang.srt.managers.server import AsyncServerManager
from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
except ImportError:
raise ImportError("Please install sglang first: pip install sglang")
logger.info(
f" Start SGLang AsyncServerManager with args: \
f" Start SGLang AsyncLLMEngine with args: \
{_get_dataclass_print_str(params)}"
)
sglang_args_dict = params.to_sglang_params()
model_path = sglang_args_dict.pop("model")
        # Create the SGLang server config
server_config = sgl.RuntimeConfig(
# Create sglang config args
server_config = sgl.ServerArgs(
model=model_path,
tensor_parallel_size=params.tensor_parallel_size,
max_model_len=params.max_model_len or 4096,
@ -337,18 +337,9 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
**sglang_args_dict.get("extras", {}),
)
        # Create the async server manager
engine = AsyncServerManager(server_config)
        # Get the tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=params.trust_remote_code,
revision=params.tokenizer_revision,
)
# Create sglang engine
engine = AsyncLLMEngine(server_config)
tokenizer = engine.tokenizer_manager.tokenizer
return engine, tokenizer
def support_async(self) -> bool:

View File

@ -1,169 +1,90 @@
import asyncio
import logging
from typing import Any, AsyncIterator, Dict, List
import os
from typing import Dict
from dbgpt.core import (
ChatCompletionResponse,
ChatCompletionResponseChoice,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
DeltaMessage,
ModelMessage,
ModelMessageRole,
)
from dbgpt.model.parameter import ModelParameters
from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
from sglang.srt.sampling.sampling_params import SamplingParams
logger = logging.getLogger(__name__)
from dbgpt.core import ModelOutput
_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
async def generate_stream(
model: Any,
tokenizer: Any,
params: Dict[str, Any],
model_messages: List[ModelMessage],
model_parameters: ModelParameters,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate stream response using SGLang."""
try:
import sglang as sgl
except ImportError:
raise ImportError("Please install sglang first: pip install sglang")
model: AsyncLLMEngine, tokenizer, params: Dict, device: str, content_length: int
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", -1.0)
frequency_penalty = float(params.get("frequency_penalty", 0.0))
presence_penalty = float(params.get("presence_penalty", 0.0))
max_new_tokens = int(params.get("max_new_tokens", 32768))
stop_str = params.get("stop", None)
stop_token_ids = params.get("stop_token_ids", None) or []
echo = params.get("echo", True)
    # Convert message format
messages = []
for msg in model_messages:
role = msg.role
if role == ModelMessageRole.HUMAN:
role = "user"
elif role == ModelMessageRole.SYSTEM:
role = "system"
elif role == ModelMessageRole.AI:
role = "assistant"
else:
role = "user"
# Handle stop_str
stop = []
if isinstance(stop_str, str) and stop_str != "":
stop.append(stop_str)
elif isinstance(stop_str, list) and stop_str != []:
stop.extend(stop_str)
messages.append({"role": role, "content": msg.content})
for tid in stop_token_ids:
s = tokenizer.decode(tid)
if s != "":
stop.append(s)
# Model params set
temperature = model_parameters.temperature
top_p = model_parameters.top_p
max_tokens = model_parameters.max_new_tokens
# make sampling params for sgl.gen
top_p = max(top_p, 1e-5)
if temperature <= 1e-5:
top_p = 1.0
# Create SGLang request
async def stream_generator():
# Use SGLang async API generate
state = sgl.RuntimeState()
gen_params = {
"stop": list(stop),
"ignore_eos": False,
}
@sgl.function
def chat(state, messages):
sgl.gen(
messages=messages,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
)
prompt_token_ids = None
if _IS_BENCHMARK:
gen_params["stop"] = []
gen_params["ignore_eos"] = True
prompt_len = content_length - max_new_tokens - 2
prompt_token_ids = tokenizer([prompt]).input_ids[0]
prompt_token_ids = prompt_token_ids[-prompt_len:]
# Start task generate
task = model.submit_task(chat, state, messages)
# Fetch result
generated_text = ""
async for output in task.stream_output():
if hasattr(output, "text"):
new_text = output.text
delta = new_text[len(generated_text) :]
generated_text = new_text
            # Create stream response
choice = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=delta),
finish_reason=None,
)
chunk = ChatCompletionStreamResponse(
id=params.get("id", "chatcmpl-default"),
model=params.get("model", "sglang-model"),
choices=[choice],
created=int(asyncio.get_event_loop().time()),
)
yield chunk
# Send complete signal
choice = ChatCompletionResponseStreamChoice(
index=0,
delta=DeltaMessage(role="assistant", content=""),
finish_reason="stop",
)
chunk = ChatCompletionStreamResponse(
id=params.get("id", "chatcmpl-default"),
model=params.get("model", "sglang-model"),
choices=[choice],
created=int(asyncio.get_event_loop().time()),
)
yield chunk
async for chunk in stream_generator():
yield chunk
async def generate(
model: Any,
tokenizer: Any,
params: Dict[str, Any],
model_messages: List[ModelMessage],
model_parameters: ModelParameters,
) -> ChatCompletionResponse:
"""Generate completion using SGLang."""
try:
import sglang as sgl
except ImportError:
raise ImportError("Please install sglang first: pip install sglang")
# Convert format to SGlang
messages = []
for msg in model_messages:
role = msg.role
if role == ModelMessageRole.HUMAN:
role = "user"
elif role == ModelMessageRole.SYSTEM:
role = "system"
elif role == ModelMessageRole.AI:
role = "assistant"
else:
role = "user"
messages.append({"role": role, "content": msg.content})
temperature = model_parameters.temperature
top_p = model_parameters.top_p
max_tokens = model_parameters.max_new_tokens
state = sgl.RuntimeState()
@sgl.function
def chat(state, messages):
sgl.gen(
messages=messages,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
)
task = await model.submit_task(chat, state, messages)
result = await task.wait()
choice = ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=result.text),
finish_reason="stop",
sampling_params = SamplingParams(
n=1,
temperature=temperature,
top_p=top_p,
max_tokens=max_new_tokens,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
**gen_params,
)
response = ChatCompletionResponse(
id=params.get("id", "chatcmpl-default"),
model=params.get("model", "sglang-model"),
choices=[choice],
created=int(asyncio.get_event_loop().time()),
)
results_generator = model.async_generate(prompt, sampling_params, stream=True)
usage = None
finish_reason = None
async for request_output in results_generator:
prompt = request_output.prompt
if echo:
text_outputs = [prompt + output.text for output in request_output.outputs]
else:
text_outputs = [output.text for output in request_output.outputs]
text_outputs = " ".join(text_outputs)
return response
# Note: usage is not supported yet
prompt_tokens = len(request_output.prompt_token_ids)
completion_tokens = sum(
len(output.token_ids) for output in request_output.outputs
)
yield ModelOutput(
text=text_outputs,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
usage=usage,
finish_reason=finish_reason,
)