feat(model): Support MLX inference (#2781)

Author: Fangyin Cheng
Date: 2025-06-19 09:30:58 +08:00
Committed by: GitHub
Parent commit: 9084c6c19c
Commit: d9d4d4b6bc

12 changed files with 5047 additions and 4662 deletions

View File

@@ -37,6 +37,7 @@ def scan_model_providers():
         base_class=LLMDeployModelParameters,
         specific_files=[
             "vllm_adapter",
+            "mlx_adapter",
             "hf_adapter",
             "llama_cpp_adapter",
             "llama_cpp_py_adapter",

View File

@@ -270,14 +270,18 @@ class LLMModelAdapter(ABC):
         ):
             return True
         return (
-            lower_model_name_or_path
-            and "deepseek" in lower_model_name_or_path
-            and (
-                "r1" in lower_model_name_or_path
-                or "reasoning" in lower_model_name_or_path
-                or "reasoner" in lower_model_name_or_path
-            )
-        ) or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            (
+                lower_model_name_or_path
+                and "deepseek" in lower_model_name_or_path
+                and (
+                    "r1" in lower_model_name_or_path
+                    or "reasoning" in lower_model_name_or_path
+                    or "reasoner" in lower_model_name_or_path
+                )
+            )
+            or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwen3" in lower_model_name_or_path)
+        )
 
     def support_async(self) -> bool:
         """Whether the loaded model supports asynchronous calls"""

View File

@@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)


@dataclass
class MLXDeployModelParameters(LLMDeployModelParameters):
    """MLX deploy model parameters."""

    provider: str = "mlx"

    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _("The path of the model, if you want to deploy a local model."),
        },
    )
    device: Optional[str] = field(
        default="auto",
        metadata={
            "order": -700,
            "help": _(
                "Device to run the model. If None, the device is automatically "
                "determined."
            ),
        },
    )
    concurrency: Optional[int] = field(
        default=100, metadata={"help": _("Model concurrency limit")}
    )

    @property
    def real_model_path(self) -> Optional[str]:
        """Get the real model path.

        If the deploy model is not local, return None.
        """
        return self._resolve_root_path(self.path)

    @property
    def real_device(self) -> Optional[str]:
        """Get the real device."""
        return self.device or super().real_device


class MLXModelAdapter(LLMModelAdapter):
    def match(
        self,
        provider: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        return provider == ModelType.MLX

    def model_type(self) -> str:
        return ModelType.MLX

    def model_param_class(
        self, model_type: str = None
    ) -> Type[MLXDeployModelParameters]:
        return MLXDeployModelParameters

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        return None

    def load_from_params(self, params: MLXDeployModelParameters):
        """Load model from parameters."""
        try:
            from mlx_lm import load
        except ImportError:
            logger.error(
                "MLX model adapter requires mlx_lm package. "
                "Please install it with `pip install mlx-lm`."
            )
            raise

        model_path = params.real_model_path
        model, tokenizer = load(model_path)
        return model, tokenizer

    def support_generate_function(self) -> bool:
        return False

    def get_generate_stream_function(
        self, model, deploy_model_params: LLMDeployModelParameters
    ):
        from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

        return generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        if not tokenizer:
            raise ValueError("tokenizer is None")

        if hasattr(tokenizer, "apply_chat_template"):
            messages = self.transform_model_messages(
                messages, convert_to_compatible_format
            )
            logger.debug(f"The messages after transform: \n{messages}")
            str_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return str_prompt
        return None


register_model_adapter(MLXModelAdapter, supported_models=COMMON_HF_MODELS)
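
As a rough usage sketch, the new adapter can be exercised directly with the parameter dataclass above. The import path, deployment name, and model path below are assumptions for illustration, and the snippet presumes a name field inherited from LLMDeployModelParameters; in DB-GPT the adapter is normally resolved through provider scanning and the model worker layer rather than instantiated by hand.

    # Minimal sketch, assuming mlx-lm is installed and the adapter module lives at
    # dbgpt.model.adapter.mlx_adapter (path inferred from the imports above).
    from dbgpt.model.adapter.mlx_adapter import MLXDeployModelParameters, MLXModelAdapter

    params = MLXDeployModelParameters(
        name="qwen3-mlx",             # hypothetical deployment name
        path="/models/Qwen3-4B-MLX",  # hypothetical local MLX model path
    )

    adapter = MLXModelAdapter()
    assert adapter.match("mlx")  # the provider string matches ModelType.MLX

    # load_from_params returns the (model, tokenizer) pair produced by mlx_lm.load.
    model, tokenizer = adapter.load_from_params(params)
    stream_fn = adapter.get_generate_stream_function(model, params)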

View File

@@ -17,6 +17,7 @@ class ModelType:
     LLAMA_CPP_SERVER = "llama.cpp.server"
     PROXY = "proxy"
     VLLM = "vllm"
+    MLX = "mlx"
     # TODO, support more model type

View File

@@ -0,0 +1,110 @@
from typing import Dict

import mlx.nn as nn
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

from dbgpt.core import ModelOutput

from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,
    parse_chat_message,
)


def generate_stream(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    params: Dict,
    device: str,
    context_len: int,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0))
    top_p = float(params.get("top_p", 1.0))
    top_k = params.get("top_k", 0)
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    # echo = bool(params.get("echo", True))
    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
    think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
    is_reasoning_model = params.get("is_reasoning_model", False)

    reasoning_patterns = [
        {"start": think_start_token, "end": think_end_token},
    ]

    sampler = make_sampler(
        temp=temperature,
        top_p=top_p,
        # min_p=min_p,
        # min_tokens_to_keep=min_tokens_to_keep,
        top_k=top_k,
        xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
    )

    # Initialize the performance monitor with estimated token count
    estimated_input_tokens = len(tokenizer.encode(prompt))
    perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)

    # Start measuring prefill phase
    perf_monitor.start_prefill()

    results_generator = stream_generate(
        model, tokenizer, prompt=prompt, max_tokens=max_new_tokens, sampler=sampler
    )

    text = ""
    is_first = True
    for res in results_generator:
        new_text = res.text
        text += new_text
        # The prompt processing tokens-per-second.
        # prompt_tps = res.prompt_tps
        # The number of tokens in the prompt.
        prompt_tokens = res.prompt_tokens
        # The number of generated tokens.
        generation_tokens = res.generation_tokens
        # The tokens-per-second for generation.
        # generation_tps = res.generation_tps
        # The peak memory used so far in GB.
        # peak_memory = res.peak_memory
        # "length", "stop" or `None`
        finish_reason = res.finish_reason

        if (
            prompt.rstrip().endswith(think_start_token)
            and is_reasoning_model
            and is_first
        ):
            text = think_start_token + "\n" + text
            is_first = False

        msg = parse_chat_message(
            text,
            extract_reasoning=is_reasoning_model,
            reasoning_patterns=reasoning_patterns,
        )

        # Update the input token count if the real prompt token count differs from
        # the initial estimate
        if perf_monitor.metrics.input_token_count != prompt_tokens:
            perf_monitor.metrics.input_token_count = prompt_tokens

        # Update performance metrics based on current token count
        perf_metrics = perf_monitor.on_tokens_received(generation_tokens)
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generation_tokens,
            "total_tokens": prompt_tokens + generation_tokens,
        }

        # Check if generation is complete
        is_complete = finish_reason is not None
        if is_complete:
            perf_monitor.end_generation()

        usage.update(perf_metrics)
        yield ModelOutput.build(
            msg.content,
            msg.reasoning_content,
            error_code=0,
            usage=usage,
            is_reasoning_model=is_reasoning_model,
        )
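
To show how this streaming function is consumed, here is a minimal sketch that builds a params dict and drains the generator. The model path, prompt construction, and sampling values are assumptions for illustration (the prompt is rendered the same way get_str_prompt does it); note that the device and context_len arguments are accepted by generate_stream but not used on MLX.

    from mlx_lm import load

    # Minimal sketch; the path is a placeholder and must point to an MLX-converted model.
    model, tokenizer = load("/models/Qwen3-4B-MLX")

    # Render a chat prompt with the tokenizer's chat template, as get_str_prompt does.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": "What is MLX?"}],
        tokenize=False,
        add_generation_prompt=True,
    )

    params = {
        "prompt": prompt,
        "temperature": 0.7,
        "top_p": 0.95,
        "max_new_tokens": 256,
        "is_reasoning_model": False,
    }

    last_output = None
    # device and context_len are illustrative values; generate_stream ignores them.
    for output in generate_stream(model, tokenizer, params, device="auto", context_len=4096):
        last_output = output  # each yield is a ModelOutput with the text so far and usage

    print(last_output)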