Mirror of https://github.com/csunny/DB-GPT.git (synced 2025-08-25 11:29:29 +00:00)

feat: add sglang config

Commit e037525dc9, parent 7b43a039ac
@@ -72,6 +72,10 @@ vllm = [
 # # https://github.com/sasha0552/pascal-pkgs-ci
 # "vllm-pascal==0.7.2; sys_platform == 'linux'"
 # ]
+sglang = [
+    # Just support GPU version on Linux
+    "sglang>=0.2.0; sys_platform == 'linux'",
+]
 quant_bnb = [
     "bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'",
     "accelerate"
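
Note (not part of this commit): the new optional dependency group can be pulled in through the package extras, for example pip install "dbgpt[sglang]" on Linux; which workspace package actually exposes this extra is an assumption here. A minimal local sanity check that the extra resolved:

    # Sketch only: confirm the sglang extra is importable (Linux, GPU build).
    # Nothing below is part of the commit; it is a local verification snippet.
    import importlib.util

    if importlib.util.find_spec("sglang") is None:
        raise SystemExit("sglang not found; install the optional 'sglang' extra on Linux")
    print("sglang extra is available")
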
@@ -314,20 +314,20 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
     def load_from_params(self, params: SGlangDeployModelParameters):
         try:
             import sglang as sgl
-            from sglang.srt.managers.server import AsyncServerManager
+            from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
         except ImportError:
             raise ImportError("Please install sglang first: pip install sglang")
 
         logger.info(
-            f" Start SGLang AsyncServerManager with args: \
+            f" Start SGLang AsyncLLMEngine with args: \
             {_get_dataclass_print_str(params)}"
         )
 
         sglang_args_dict = params.to_sglang_params()
         model_path = sglang_args_dict.pop("model")
 
-        # Create SGLang server config
-        server_config = sgl.RuntimeConfig(
+        # Create sglang config args
+        server_config = sgl.ServerArgs(
             model=model_path,
             tensor_parallel_size=params.tensor_parallel_size,
             max_model_len=params.max_model_len or 4096,
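
Note (not part of this commit): the hunk above swaps sgl.RuntimeConfig for sgl.ServerArgs as the engine configuration object. In recent sglang releases the canonical import is sglang.srt.server_args.ServerArgs, and the documented field names are model_path, tp_size and context_length rather than the keywords used in the diff; the sketch below is an assumption to verify against the installed sglang version.

    # Hedged sketch: building ServerArgs with the field names sglang itself uses.
    # Field names and the model id are assumptions, not taken from this commit.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model id
        tp_size=1,
        context_length=4096,
    )
    print(args.model_path, args.tp_size, args.context_length)
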
@@ -337,18 +337,9 @@ class SGLangModelAdapterWrapper(LLMModelAdapter):
             **sglang_args_dict.get("extras", {}),
         )
 
-        # Create async server manager
-        engine = AsyncServerManager(server_config)
-
-        # Get tokenizer
-        from transformers import AutoTokenizer
-
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            trust_remote_code=params.trust_remote_code,
-            revision=params.tokenizer_revision,
-        )
+        # Create sglang engine
+        engine = AsyncLLMEngine(server_config)
+        tokenizer = engine.tokenizer_manager.tokenizer
 
         return engine, tokenizer
 
     def support_async(self) -> bool:
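
Note (not part of this commit): this hunk stops loading a separate transformers tokenizer and instead reuses the tokenizer the engine already owns via engine.tokenizer_manager.tokenizer. A standalone sketch of that pattern with the offline sglang engine; the constructor keywords are assumptions to verify locally.

    # Hedged sketch: reuse the tokenizer held by the sglang engine.
    import sglang as sgl

    engine = sgl.Engine(model_path="Qwen/Qwen2.5-0.5B-Instruct", tp_size=1)  # assumed kwargs
    tokenizer = engine.tokenizer_manager.tokenizer
    print(tokenizer.decode(tokenizer.encode("hello")))
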
@@ -1,169 +1,90 @@
-import asyncio
-import logging
-from typing import Any, AsyncIterator, Dict, List
-
-from dbgpt.core import (
-    ChatCompletionResponse,
-    ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice,
-    ChatCompletionStreamResponse,
-    ChatMessage,
-    DeltaMessage,
-    ModelMessage,
-    ModelMessageRole,
-)
-from dbgpt.model.parameter import ModelParameters
-
-logger = logging.getLogger(__name__)
-
-
-async def generate_stream(
-    model: Any,
-    tokenizer: Any,
-    params: Dict[str, Any],
-    model_messages: List[ModelMessage],
-    model_parameters: ModelParameters,
-) -> AsyncIterator[ChatCompletionStreamResponse]:
-    """Generate stream response using SGLang."""
-    try:
-        import sglang as sgl
-    except ImportError:
-        raise ImportError("Please install sglang first: pip install sglang")
-
-    # Message format convert
-    messages = []
-    for msg in model_messages:
-        role = msg.role
-        if role == ModelMessageRole.HUMAN:
-            role = "user"
-        elif role == ModelMessageRole.SYSTEM:
-            role = "system"
-        elif role == ModelMessageRole.AI:
-            role = "assistant"
-        else:
-            role = "user"
-
-        messages.append({"role": role, "content": msg.content})
-
-    # Model params set
-    temperature = model_parameters.temperature
-    top_p = model_parameters.top_p
-    max_tokens = model_parameters.max_new_tokens
-
-    # Create SGLang request
-    async def stream_generator():
-        # Use SGLang async API generate
-        state = sgl.RuntimeState()
-
-        @sgl.function
-        def chat(state, messages):
-            sgl.gen(
-                messages=messages,
-                temperature=temperature,
-                top_p=top_p,
-                max_tokens=max_tokens,
-            )
-
-        # Start task generate
-        task = model.submit_task(chat, state, messages)
-
-        # Fetch result
-        generated_text = ""
-        async for output in task.stream_output():
-            if hasattr(output, "text"):
-                new_text = output.text
-                delta = new_text[len(generated_text) :]
-                generated_text = new_text
-
-                # Create Stream reponse
-                choice = ChatCompletionResponseStreamChoice(
-                    index=0,
-                    delta=DeltaMessage(role="assistant", content=delta),
-                    finish_reason=None,
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=params.get("id", "chatcmpl-default"),
-                    model=params.get("model", "sglang-model"),
-                    choices=[choice],
-                    created=int(asyncio.get_event_loop().time()),
-                )
-                yield chunk
-
-        # Send complete signal
-        choice = ChatCompletionResponseStreamChoice(
-            index=0,
-            delta=DeltaMessage(role="assistant", content=""),
-            finish_reason="stop",
-        )
-        chunk = ChatCompletionStreamResponse(
-            id=params.get("id", "chatcmpl-default"),
-            model=params.get("model", "sglang-model"),
-            choices=[choice],
-            created=int(asyncio.get_event_loop().time()),
-        )
-        yield chunk
-
-    async for chunk in stream_generator():
-        yield chunk
-
-
-async def generate(
-    model: Any,
-    tokenizer: Any,
-    params: Dict[str, Any],
-    model_messages: List[ModelMessage],
-    model_parameters: ModelParameters,
-) -> ChatCompletionResponse:
-    """Generate completion using SGLang."""
-    try:
-        import sglang as sgl
-    except ImportError:
-        raise ImportError("Please install sglang first: pip install sglang")
-
-    # Convert format to SGlang
-    messages = []
-    for msg in model_messages:
-        role = msg.role
-        if role == ModelMessageRole.HUMAN:
-            role = "user"
-        elif role == ModelMessageRole.SYSTEM:
-            role = "system"
-        elif role == ModelMessageRole.AI:
-            role = "assistant"
-        else:
-            role = "user"
-
-        messages.append({"role": role, "content": msg.content})
-
-    temperature = model_parameters.temperature
-    top_p = model_parameters.top_p
-    max_tokens = model_parameters.max_new_tokens
-
-    state = sgl.RuntimeState()
-
-    @sgl.function
-    def chat(state, messages):
-        sgl.gen(
-            messages=messages,
-            temperature=temperature,
-            top_p=top_p,
-            max_tokens=max_tokens,
-        )
-
-    task = await model.submit_task(chat, state, messages)
-    result = await task.wait()
-
-    choice = ChatCompletionResponseChoice(
-        index=0,
-        message=ChatMessage(role="assistant", content=result.text),
-        finish_reason="stop",
-    )
-
-    response = ChatCompletionResponse(
-        id=params.get("id", "chatcmpl-default"),
-        model=params.get("model", "sglang-model"),
-        choices=[choice],
-        created=int(asyncio.get_event_loop().time()),
-    )
-
-    return response
+import os
+from typing import Dict
+
+from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine
+from sglang.srt.sampling.sampling_params import SamplingParams
+
+from dbgpt.core import ModelOutput
+
+_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"
+
+
+async def generate_stream(
+    model: AsyncLLMEngine, tokenizer, params: Dict, device: str, content_length: int
+):
+    prompt = params["prompt"]
+    temperature = float(params.get("temperature", 1.0))
+    top_p = float(params.get("top_p", 1.0))
+    top_k = params.get("top_k", -1.0)
+    frequency_penalty = float(params.get("frequency_penalty", 0.0))
+    presence_penalty = float(params.get("presence_penalty", 0.0))
+    max_new_tokens = int(params.get("max_new_tokens", 32768))
+    stop_str = params.get("stop", None)
+    stop_token_ids = params.get("stop_token_ids", None) or []
+    echo = params.get("echo", True)
+
+    # Handle stop_str
+    stop = []
+    if isinstance(stop_str, str) and stop_str != "":
+        stop.append(stop_str)
+    elif isinstance(stop_str, list) and stop_str != []:
+        stop.extend(stop_str)
+
+    for tid in stop_token_ids:
+        s = tokenizer.decode(tid)
+        if s != "":
+            stop.append(s)
+
+    # make sampling params for sgl.gen
+    top_p = max(top_p, 1e-5)
+    if temperature <= 1e-5:
+        top_p = 1.0
+
+    gen_params = {
+        "stop": list(stop),
+        "ignore_eos": False,
+    }
+
+    prompt_token_ids = None
+    if _IS_BENCHMARK:
+        gen_params["stop"] = []
+        gen_params["ignore_eos"] = True
+        prompt_len = content_length - max_new_tokens - 2
+        prompt_token_ids = tokenizer([prompt]).input_ids[0]
+        prompt_token_ids = prompt_token_ids[-prompt_len:]
+
+    sampling_params = SamplingParams(
+        n=1,
+        temperature=temperature,
+        top_p=top_p,
+        max_tokens=max_new_tokens,
+        top_k=top_k,
+        presence_penalty=presence_penalty,
+        frequency_penalty=frequency_penalty,
+        **gen_params,
+    )
+
+    results_generator = model.async_generate(prompt, sampling_params, stream=True)
+    usage = None
+    finish_reason = None
+    async for request_output in results_generator:
+        prompt = request_output.prompt
+        if echo:
+            text_outputs = [prompt + output.text for output in request_output.outputs]
+        else:
+            text_outputs = [output.text for output in request_output.outputs]
+        text_outputs = " ".join(text_outputs)
+
+        # Note: usage is not supported yet
+        prompt_tokens = len(request_output.prompt_token_ids)
+        completion_tokens = sum(
+            len(output.token_ids) for output in request_output.outputs
+        )
+
+        yield ModelOutput(
+            text=text_outputs,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            usage=usage,
+            finish_reason=finish_reason,
+        )
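
Note (not part of this commit): a hedged driver sketch for the rewritten generate_stream above, assuming it is imported from the module shown in this hunk. The engine constructor keywords and the model id are placeholders to verify against the installed sglang version.

    # Hedged sketch: drive generate_stream and print each streamed ModelOutput.
    import asyncio

    from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine

    # from <this module> import generate_stream  # module path depends on the repo layout


    async def main():
        engine = AsyncLLMEngine(model_path="Qwen/Qwen2.5-0.5B-Instruct")  # assumed kwargs
        tokenizer = engine.tokenizer_manager.tokenizer
        params = {
            "prompt": "What is DB-GPT?",
            "temperature": 0.7,
            "max_new_tokens": 128,
            "echo": False,
        }
        async for output in generate_stream(
            engine, tokenizer, params, device="cuda", content_length=4096
        ):
            print(output.text)


    asyncio.run(main())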