diff --git a/packages/dbgpt-accelerator/dbgpt-acc-auto/pyproject.toml b/packages/dbgpt-accelerator/dbgpt-acc-auto/pyproject.toml index 38505a241..575cce298 100644 --- a/packages/dbgpt-accelerator/dbgpt-acc-auto/pyproject.toml +++ b/packages/dbgpt-accelerator/dbgpt-acc-auto/pyproject.toml @@ -72,6 +72,10 @@ vllm = [ # # https://github.com/sasha0552/pascal-pkgs-ci # "vllm-pascal==0.7.2; sys_platform == 'linux'" # ] +sglang = [ + # Just support GPU version on Linux + "sglang>=0.2.0; sys_platform == 'linux'", +] quant_bnb = [ "bitsandbytes>=0.39.0; sys_platform == 'win32' or sys_platform == 'linux'", "accelerate" diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py b/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py index 27b104668..d36cd1ab4 100644 --- a/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py +++ b/packages/dbgpt-core/src/dbgpt/model/adapter/sglang_adapter.py @@ -314,20 +314,20 @@ class SGLangModelAdapterWrapper(LLMModelAdapter): def load_from_params(self, params: SGlangDeployModelParameters): try: import sglang as sgl - from sglang.srt.managers.server import AsyncServerManager + from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine except ImportError: raise ImportError("Please install sglang first: pip install sglang") logger.info( - f" Start SGLang AsyncServerManager with args: \ + f" Start SGLang AsyncLLMEngine with args: \ {_get_dataclass_print_str(params)}" ) sglang_args_dict = params.to_sglang_params() model_path = sglang_args_dict.pop("model") - # 创建SGLang服务器配置 - server_config = sgl.RuntimeConfig( + # Create sglang config args + server_config = sgl.ServerArgs( model=model_path, tensor_parallel_size=params.tensor_parallel_size, max_model_len=params.max_model_len or 4096, @@ -337,18 +337,9 @@ class SGLangModelAdapterWrapper(LLMModelAdapter): **sglang_args_dict.get("extras", {}), ) - # 创建异步服务器管理器 - engine = AsyncServerManager(server_config) - - # 获取tokenizer - from transformers import 
import os
from typing import Dict

from sglang.srt.entrypoints.engine import Engine as AsyncLLMEngine

# No longer used below (the engine takes a plain dict of sampling options);
# the import is kept so the module namespace is unchanged.
from sglang.srt.sampling.sampling_params import SamplingParams  # noqa: F401

from dbgpt.core import ModelOutput

_IS_BENCHMARK = os.getenv("DB_GPT_MODEL_BENCHMARK", "False").lower() == "true"


async def generate_stream(
    model: AsyncLLMEngine, tokenizer, params: Dict, device: str, content_length: int
):
    """Stream a completion from an SGLang engine as ``ModelOutput`` chunks.

    Args:
        model: The SGLang engine used to run generation.
        tokenizer: Tokenizer used to decode ``stop_token_ids`` into stop
            strings and, in benchmark mode, to tokenize the prompt.
        params: Request parameters. ``prompt`` is required. Optional keys read
            here: ``temperature``, ``top_p``, ``top_k``, ``frequency_penalty``,
            ``presence_penalty``, ``max_new_tokens``, ``stop``,
            ``stop_token_ids`` and ``echo``.
        device: Unused; accepted so the signature matches what callers pass.
        content_length: Context window size. Only used in benchmark mode to
            decide how many trailing prompt tokens to keep.

    Yields:
        ModelOutput: One item per streamed chunk, with the prompt prepended to
        the text when ``echo`` is truthy.
    """
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 1.0))
    top_p = float(params.get("top_p", 1.0))
    # top_k must be an integer; -1 disables top-k filtering.
    raw_top_k = params.get("top_k")
    top_k = -1 if raw_top_k is None else int(raw_top_k)
    frequency_penalty = float(params.get("frequency_penalty", 0.0))
    presence_penalty = float(params.get("presence_penalty", 0.0))
    max_new_tokens = int(params.get("max_new_tokens", 32768))
    stop_str = params.get("stop", None)
    stop_token_ids = params.get("stop_token_ids", None) or []
    echo = params.get("echo", True)

    # Collect stop strings from both the textual and the token-id form.
    stop = []
    if isinstance(stop_str, str):
        if stop_str:
            stop.append(stop_str)
    elif isinstance(stop_str, list):
        stop.extend(stop_str)
    for tid in stop_token_ids:
        decoded = tokenizer.decode(tid)
        if decoded != "":
            stop.append(decoded)

    # Keep top_p strictly positive; with (near-)greedy decoding, nucleus
    # sampling is switched off altogether.
    top_p = max(top_p, 1e-5)
    if temperature <= 1e-5:
        top_p = 1.0

    # SGLang takes sampling options as a plain dict and names the length
    # limit ``max_new_tokens``.
    sampling_params = {
        "n": 1,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "max_new_tokens": max_new_tokens,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "stop": stop,
        "ignore_eos": False,
    }

    request_kwargs = {"prompt": prompt}
    if _IS_BENCHMARK:
        # Benchmark runs must always produce max_new_tokens tokens, so stop
        # conditions are disabled and the prompt is cut to fit the window.
        sampling_params["stop"] = []
        sampling_params["ignore_eos"] = True
        prompt_len = content_length - max_new_tokens - 2
        prompt_token_ids = tokenizer([prompt]).input_ids[0]
        prompt_token_ids = prompt_token_ids[-prompt_len:]
        # Send the truncated ids; otherwise the truncation has no effect.
        request_kwargs = {"input_ids": prompt_token_ids}

    # With stream=True the coroutine resolves to an async iterator of chunks.
    results_generator = await model.async_generate(
        sampling_params=sampling_params, stream=True, **request_kwargs
    )

    # Note: usage is not supported yet
    usage = None
    async for chunk in results_generator:
        # NOTE(review): assumes each chunk's "text" is the cumulative output
        # so far (engine not started with incremental stream output) - confirm.
        generated = chunk.get("text", "")
        meta_info = chunk.get("meta_info") or {}

        # The finish reason is assumed to arrive as a mapping such as
        # {"type": "stop", ...} once generation ends, and None before that.
        finish_reason = meta_info.get("finish_reason")
        if isinstance(finish_reason, dict):
            finish_reason = finish_reason.get("type")

        yield ModelOutput(
            text=prompt + generated if echo else generated,
            prompt_tokens=meta_info.get("prompt_tokens", 0),
            completion_tokens=meta_info.get("completion_tokens", 0),
            usage=usage,
            finish_reason=finish_reason,
        )