feat: add sglang config

csunny 2025-03-30 09:42:25 +08:00
parent 7b43a039ac
commit e037525dc9
3 changed files with 87 additions and 171 deletions

pyproject.toml

@@ -72,6 +72,10 @@ vllm = [
 # # https://github.com/sasha0552/pascal-pkgs-ci
 # "vllm-pascal==0.7.2; sys_platform == 'linux'"
 # ]
+sglang = [
+    # Only the GPU version on Linux is supported
+    "sglang>=0.2.0; sys_platform == 'linux'",
+]
 quant_bnb = [
     "bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
     "accelerate"


@@ -314,20 +314,20 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
     def load_from_params(self, params: SGlangDeployModelParameters):
         try:
             import sglang as sgl
-            from sglang.srt.managers.server import AsyncServerManager
+            from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
         except ImportError:
             raise ImportError("Please install sglang first: pip install sglang")

         logger.info(
-            f"Start SGLang AsyncServerManager with args: \
+            f"Start SGLang AsyncLLMEngine with args: \
             {_get_dataclass_print_str(params)}"
         )

         sglang_args_dict = params.to_sglang_params()
         model_path = sglang_args_dict.pop("model")
-        # Create the SGLang server config
-        server_config = sgl.RuntimeConfig(
+        # Create the SGLang config args
+        server_config = sgl.ServerArgs(
             model=model_path,
             tensor_parallel_size=params.tensor_parallel_size,
             max_model_len=params.max_model_len or 4096,
@@ -337,18 +337,9 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
             **sglang_args_dict.get("extras", {}),
         )

-        # Create the async server manager
-        engine = AsyncServerManager(server_config)
-
-        # Get the tokenizer
-        from transformers import AutoTokenizer
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            trust_remote_code=params.trust_remote_code,
-            revision=params.tokenizer_revision,
-        )
+        # Create the SGLang engine
+        engine = AsyncLLMEngine(server_config)
+        tokenizer = engine.tokenizer_manager.tokenizer
         return engine, tokenizer

     def support_async(self) -> bool:
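The second hunk also changes where the tokenizer comes from: instead of a separate `AutoTokenizer.from_pretrained(...)` load, the adapter reuses the tokenizer the engine already holds, so encode/decode behavior matches what the engine serves with. A rough caller-side sketch; the `adapter` variable and all field values are illustrative, and the constructor shape of `SGlangDeployModelParameters` is assumed from the attributes this diff reads:

    # Hypothetical usage; field names mirror the attributes the diff accesses
    # (tensor_parallel_size, max_model_len, trust_remote_code), values invented.
    params = SGlangDeployModelParameters(
        model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model id
        tensor_parallel_size=1,
        max_model_len=4096,
        trust_remote_code=True,
    )
    engine, tokenizer = adapter.load_from_params(params)
    # tokenizer is engine.tokenizer_manager.tokenizer, so token counting and
    # stop-string decoding stay consistent with the serving engine.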


@@ -1,169 +1,90 @@
-import asyncio
-import logging
-from typing import Any, AsyncIterator, Dict, List
-
-from dbgpt.core import (
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    ModelMessage,
-    ModelMessageRole,
-)
-from dbgpt.model.parameter import ModelParameters
-
-logger = logging.getLogger(__name__)
-
-
-async def generate_stream(
-    model: Any,
-    tokenizer: Any,
-    params: Dict[str, Any],
-    model_messages: List[ModelMessage],
-    model_parameters: ModelParameters,
-) -> AsyncIterator[ChatCompletionStreamResponse]:
-    """Generate stream response using SGLang."""
-    try:
-        import sglang as sgl
-    except ImportError:
-        raise ImportError("Please install sglang first: pip install sglang")
-
-    # Convert message format
-    messages = []
-    for msg in model_messages:
-        role = msg.role
-        if role == ModelMessageRole.HUMAN:
-            role = "user"
-        elif role == ModelMessageRole.SYSTEM:
-            role = "system"
-        elif role == ModelMessageRole.AI:
-            role = "assistant"
-        else:
-            role = "user"
-        messages.append({"role": role, "content": msg.content})
-
-    # Set model params
-    temperature = model_parameters.temperature
-    top_p = model_parameters.top_p
-    max_tokens = model_parameters.max_new_tokens
-
-    # Create the SGLang request
-    async def stream_generator():
-        # Generate via the SGLang async API
-        state = sgl.RuntimeState()
-
-        @sgl.function
-        def chat(state, messages):
-            sgl.gen(
-                messages=messages,
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_tokens,
-            )
-
-        # Start the generation task
-        task = model.submit_task(chat, state, messages)
-
-        # Fetch results
-        generated_text = ""
-        async for output in task.stream_output():
-            if hasattr(output, "text"):
-                new_text = output.text
-                delta = new_text[len(generated_text):]
-                generated_text = new_text
-                # Create the stream response
-                choice = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=DeltaMessage(role="assistant", content=delta),
-                    finish_reason=None,
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=params.get("id", "chatcmpl-default"),
-                    model=params.get("model", "sglang-model"),
-                    choices=[choice],
-                    created=int(asyncio.get_event_loop().time()),
-                )
-                yield chunk
-
-        # Send the completion signal
-        choice = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role="assistant", content=""),
-            finish_reason="stop",
-        )
-        chunk = ChatCompletionStreamResponse(
-            id=params.get("id", "chatcmpl-default"),
-            model=params.get("model", "sglang-model"),
-            choices=[choice],
-            created=int(asyncio.get_event_loop().time()),
-        )
-        yield chunk
-
-    async for chunk in stream_generator():
-        yield chunk
-
-
-async def generate(
-    model: Any,
-    tokenizer: Any,
-    params: Dict[str, Any],
-    model_messages: List[ModelMessage],
-    model_parameters: ModelParameters,
-) -> ChatCompletionResponse:
-    """Generate completion using SGLang."""
-    try:
-        import sglang as sgl
-    except ImportError:
-        raise ImportError("Please install sglang first: pip install sglang")
-
-    # Convert messages to the SGLang format
-    messages = []
-    for msg in model_messages:
-        role = msg.role
-        if role == ModelMessageRole.HUMAN:
-            role = "user"
-        elif role == ModelMessageRole.SYSTEM:
-            role = "system"
-        elif role == ModelMessageRole.AI:
-            role = "assistant"
-        else:
-            role = "user"
-        messages.append({"role": role, "content": msg.content})
-
-    temperature = model_parameters.temperature
-    top_p = model_parameters.top_p
-    max_tokens = model_parameters.max_new_tokens
-
-    state = sgl.RuntimeState()
-
-    @sgl.function
-    def chat(state, messages):
-        sgl.gen(
-            messages=messages,
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-        )
-
-    task = await model.submit_task(chat, state, messages)
-    result = await task.wait()
-
-    choice = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=result.text),
-        finish_reason="stop",
-    )
-    response = ChatCompletionResponse(
-        id=params.get("id", "chatcmpl-default"),
-        model=params.get("model", "sglang-model"),
-        choices=[choice],
-        created=int(asyncio.get_event_loop().time()),
-    )
-    return response
+import os
+from typing import Dict
+
+from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
+from sglang.srt.sampling.sampling_params import SamplingParams
+
+from dbgpt.core import ModelOutput
+
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
+async def generate_stream(
+    model: AsyncLLMEngine, tokenizer, params: Dict, device: str, content_length: int
+):
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = params.get("top_k", -1.0)
+    frequency_penalty = float(params.get("frequency_penalty", 0.0))
+    presence_penalty = float(params.get("presence_penalty", 0.0))
+    max_new_tokens = int(params.get("max_new_tokens", 32768))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    echo = params.get("echo", True)
+
+    # Handle stop_str
+    stop = []
+    if isinstance(stop_str, str) and stop_str != "":
+        stop.append(stop_str)
+    elif isinstance(stop_str, list) and stop_str != []:
+        stop.extend(stop_str)
+
+    for tid in stop_token_ids:
+        s = tokenizer.decode(tid)
+        if s != "":
+            stop.append(s)
+
+    # Make sampling params for sgl.gen
+    top_p = max(top_p, 1e-5)
+    if temperature <= 1e-5:
+        top_p = 1.0
+
+    gen_params = {
+        "stop": list(stop),
+        "ignore_eos": False,
+    }
+    prompt_token_ids = None
+    if _IS_BENCHMARK:
+        gen_params["stop"] = []
+        gen_params["ignore_eos"] = True
+        prompt_len = content_length - max_new_tokens - 2
+        prompt_token_ids = tokenizer([prompt]).input_ids[0]
+        prompt_token_ids = prompt_token_ids[-prompt_len:]
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=max_new_tokens,
+        top_k=top_k,
+        presence_penalty=presence_penalty,
+        frequency_penalty=frequency_penalty,
+        **gen_params,
+    )
+
+    results_generator = model.async_generate(prompt, sampling_params, stream=True)
+    usage = None
+    finish_reason = None
+    async for request_output in results_generator:
+        prompt = request_output.prompt
+        if echo:
+            text_outputs = [prompt + output.text for output in request_output.outputs]
+        else:
+            text_outputs = [output.text for output in request_output.outputs]
+        text_outputs = " ".join(text_outputs)
+
+        # Note: usage is not supported yet
+        prompt_tokens = len(request_output.prompt_token_ids)
+        completion_tokens = sum(
+            len(output.token_ids) for output in request_output.outputs
+        )
+        yield ModelOutput(
+            text=text_outputs,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            usage=usage,
+            finish_reason=finish_reason,
+        )
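Downstream, the rewritten generate_stream yields ModelOutput objects rather than OpenAI-style delta chunks, and each yielded text is the full output so far, not an increment. A hedged driver sketch, assuming `engine` and `tokenizer` come from load_from_params above and that the streamed text is cumulative as the loop suggests:

    import asyncio

    async def demo(engine, tokenizer):
        # Keys mirror what generate_stream reads from `params`; values illustrative.
        params = {
            "prompt": "Explain tensor parallelism in one sentence.",
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 128,
            "echo": False,
        }
        seen = ""
        async for output in generate_stream(
            engine, tokenizer, params, device="cuda", content_length=8192
        ):
            # Print only the newly generated suffix of the cumulative text.
            print(output.text[len(seen):], end="", flush=True)
            seen = output.text

    # asyncio.run(demo(engine, tokenizer))  # engine/tokenizer from the adapter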