feat(model): Support MLX inference (#2781)

Fangyin Cheng 2025-06-19 09:30:58 +08:00 committed by GitHub
parent 9084c6c19c
commit d9d4d4b6bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 5047 additions and 4662 deletions


@ -0,0 +1,36 @@
[system]
# Load language from the environment variable (it is set by the hook)
language = "${env:DBGPT_LANG:-zh}"
api_keys = []
encrypt_key = "your_secret_key"
# Server Configurations
[service.web]
host = "0.0.0.0"
port = 5670
[service.web.database]
type = "sqlite"
path = "pilot/meta_data/dbgpt.db"
[rag.storage]
[rag.storage.vector]
type = "chroma"
persist_path = "pilot/data"
# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If path is not provided, the model will be downloaded from the Hugging Face model hub.
# Uncomment the following line to specify a model path on the local file system:
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"
[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
provider = "hf"
# If path is not provided, the model will be downloaded from the Hugging Face model hub.
# Uncomment the following line to specify a model path on the local file system:
# path = "the-model-path-in-the-local-file-system"


@ -0,0 +1,43 @@
# MLX Inference
DB-GPT supports [MLX](https://github.com/ml-explore/mlx-lm) inference via the `mlx-lm` library, a fast and easy-to-use way to run LLMs on Apple silicon.
## Install dependencies
`MLX` is an optional dependency in DB-GPT. You can install it by adding `--extra "mlx"` when installing dependencies. Note that the `mlx-lm` package is only installed on macOS (`sys_platform == 'darwin'`).
```bash
# Use uv to install dependencies needed for mlx
# Install core dependencies and select desired extensions
uv sync --all-packages \
--extra "base" \
--extra "hf" \
--extra "mlx" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```
## Modify configuration file
After installing the dependencies, you can modify your configuration file to use the `mlx` provider.
```toml
# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If path is not provided, the model will be downloaded from the Hugging Face model hub.
# Uncomment the following line to specify a model path on the local file system:
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"
```
## Run the model
You can run the model using the following command:
```bash
uv run dbgpt start webserver --config {your_config_file}
```
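Once the webserver is up, it listens on the host and port configured under `[service.web]` (`0.0.0.0:5670` in the example configuration). A quick reachability check, assuming the default port:
```bash
# Assumes the default [service.web] port (5670) from the example config.
curl -I http://localhost:5670/
```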


@ -170,6 +170,10 @@ const sidebars = {
type: 'doc',
id: 'installation/advanced_usage/vLLM_inference',
},
{
type: 'doc',
id: 'installation/advanced_usage/mlx_inference',
},
{
type: 'doc',
id: 'installation/advanced_usage/Llamacpp_server',

File diff suppressed because it is too large.


@ -68,6 +68,9 @@ vllm = [
# Just support GPU version on Linux
"vllm>=0.7.0; sys_platform == 'linux'",
]
mlx = [
"mlx-lm>=0.25.2; sys_platform == 'darwin'",
]
# vllm_pascal = [
# # https://github.com/sasha0552/pascal-pkgs-ci
# "vllm-pascal==0.7.2; sys_platform == 'linux'"


@ -279,7 +279,9 @@ def load_config(config_file: str = None) -> ApplicationConfig:
from dbgpt.configs.model_config import ROOT_PATH as DBGPT_ROOT_PATH
if config_file is None:
-        config_file = os.path.join(DBGPT_ROOT_PATH, "configs", "dbgpt-siliconflow.toml")
+        config_file = os.path.join(
+            DBGPT_ROOT_PATH, "configs", "dbgpt-proxy-siliconflow.toml"
+        )
elif not os.path.isabs(config_file):
# If config_file is a relative path, make it relative to DBGPT_ROOT_PATH
config_file = os.path.join(DBGPT_ROOT_PATH, config_file)


@ -37,6 +37,7 @@ def scan_model_providers():
base_class=LLMDeployModelParameters,
specific_files=[
"vllm_adapter",
"mlx_adapter",
"hf_adapter",
"llama_cpp_adapter",
"llama_cpp_py_adapter",


@ -270,14 +270,18 @@ class LLMModelAdapter(ABC):
):
return True
        return (
-            lower_model_name_or_path
-            and "deepseek" in lower_model_name_or_path
-            and (
-                "r1" in lower_model_name_or_path
-                or "reasoning" in lower_model_name_or_path
-                or "reasoner" in lower_model_name_or_path
-            )
-        ) or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            (
+                lower_model_name_or_path
+                and "deepseek" in lower_model_name_or_path
+                and (
+                    "r1" in lower_model_name_or_path
+                    or "reasoning" in lower_model_name_or_path
+                    or "reasoner" in lower_model_name_or_path
+                )
+            )
+            or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwen3" in lower_model_name_or_path)
+        )
def support_async(self) -> bool:
"""Whether the loaded model supports asynchronous calls"""


@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type
from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _
logger = logging.getLogger(__name__)
@dataclass
class MLXDeployModelParameters(LLMDeployModelParameters):
"""Local deploy model parameters."""
provider: str = "mlx"
path: Optional[str] = field(
default=None,
metadata={
"order": -800,
"help": _("The path of the model, if you want to deploy a local model."),
},
)
device: Optional[str] = field(
default="auto",
metadata={
"order": -700,
"help": _(
"Device to run model. If None, the device is automatically determined"
),
},
)
concurrency: Optional[int] = field(
default=100, metadata={"help": _("Model concurrency limit")}
)
@property
def real_model_path(self) -> Optional[str]:
"""Get the real model path.
If deploy model is not local, return None.
"""
return self._resolve_root_path(self.path)
@property
def real_device(self) -> Optional[str]:
"""Get the real device."""
return self.device or super().real_device
class MLXModelAdapter(LLMModelAdapter):
def match(
self,
provider: str,
model_name: Optional[str] = None,
model_path: Optional[str] = None,
) -> bool:
return provider == ModelType.MLX
def model_type(self) -> str:
return ModelType.MLX
def model_param_class(
self, model_type: str = None
) -> Type[MLXDeployModelParameters]:
return MLXDeployModelParameters
def get_default_conv_template(
self, model_name: str, model_path: str
) -> Optional[ConversationAdapter]:
return None
def load_from_params(self, params: MLXDeployModelParameters):
"""Load model from parameters."""
try:
from mlx_lm import load
except ImportError:
logger.error(
"MLX model adapter requires mlx_lm package. "
"Please install it with `pip install mlx-lm`."
)
raise
model_path = params.real_model_path
model, tokenizer = load(model_path)
return model, tokenizer
def support_generate_function(self) -> bool:
return False
def get_generate_stream_function(
self, model, deploy_model_params: LLMDeployModelParameters
):
from dbgpt.model.llm.llm_out.mlx_llm import generate_stream
return generate_stream
def get_str_prompt(
self,
params: Dict,
messages: List[ModelMessage],
tokenizer: Any,
prompt_template: str = None,
convert_to_compatible_format: bool = False,
) -> Optional[str]:
if not tokenizer:
raise ValueError("tokenizer is is None")
if hasattr(tokenizer, "apply_chat_template"):
messages = self.transform_model_messages(
messages, convert_to_compatible_format
)
logger.debug(f"The messages after transform: \n{messages}")
str_prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
return str_prompt
return None
register_model_adapter(MLXModelAdapter, supported_models=COMMON_HF_MODELS)
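The adapter is a thin wrapper around `mlx_lm`: `load()` returns the model and tokenizer, the tokenizer's `apply_chat_template` builds the prompt string, and generation is delegated to the streaming function in `mlx_llm.py` below. A rough standalone sketch of the same flow, assuming `mlx-lm` is installed and the model id can be resolved from the Hugging Face hub:
```python
from mlx_lm import load, stream_generate

# Same mlx_lm calls the adapter wraps; the model id is the one used in the docs.
model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")

messages = [{"role": "user", "content": "Hello, who are you?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

for res in stream_generate(model, tokenizer, prompt=prompt, max_tokens=256):
    print(res.text, end="", flush=True)
```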


@ -17,6 +17,7 @@ class ModelType:
LLAMA_CPP_SERVER = "llama.cpp.server"
PROXY = "proxy"
VLLM = "vllm"
MLX = "mlx"
# TODO, support more model type


@ -0,0 +1,110 @@
from typing import Dict
import mlx.nn as nn
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper
from dbgpt.core import ModelOutput
from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
_DEFAULT_THINK_END_TOKEN,
_DEFAULT_THINK_START_TOKEN,
parse_chat_message,
)
def generate_stream(
model: nn.Module,
tokenizer: TokenizerWrapper,
params: Dict,
device: str,
context_len: int,
):
prompt = params["prompt"]
temperature = float(params.get("temperature", 0))
top_p = float(params.get("top_p", 1.0))
top_k = params.get("top_k", 0)
max_new_tokens = int(params.get("max_new_tokens", 2048))
# echo = bool(params.get("echo", True))
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
reasoning_patterns = [
{"start": think_start_token, "end": think_end_token},
]
sampler = make_sampler(
temp=temperature,
top_p=top_p,
# min_p=min_p,
# min_tokens_to_keep=min_tokens_to_keep,
top_k=top_k,
xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
)
# Initialize the performance monitor with estimated token count
estimated_input_tokens = len(tokenizer.encode(prompt))
perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)
# Start measuring prefill phase
perf_monitor.start_prefill()
results_generator = stream_generate(
model, tokenizer, prompt=prompt, max_tokens=max_new_tokens, sampler=sampler
)
text = ""
is_first = True
for res in results_generator:
new_text = res.text
text += new_text
# The prompt processing tokens-per-second.
# prompt_tps = res.prompt_tps
# The number of tokens in the prompt.
prompt_tokens = res.prompt_tokens
# The number of generated tokens.
generation_tokens = res.generation_tokens
# The tokens-per-second for generation.
# generation_tps = res.generation_tps
# The peak memory used so far in GB.
# peak_memory = res.peak_memory
# "length", "stop" or `None`
finish_reason = res.finish_reason
if (
prompt.rstrip().endswith(think_start_token)
and is_reasoning_model
and is_first
):
text = think_start_token + "\n" + text
is_first = False
msg = parse_chat_message(
text,
extract_reasoning=is_reasoning_model,
reasoning_patterns=reasoning_patterns,
)
# If this is the first iteration, update the input token count
if perf_monitor.metrics.input_token_count != prompt_tokens:
perf_monitor.metrics.input_token_count = prompt_tokens
# Update performance metrics based on current token count
perf_metrics = perf_monitor.on_tokens_received(generation_tokens)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": generation_tokens,
"total_tokens": prompt_tokens + generation_tokens,
}
# Check if generation is complete
is_complete = finish_reason is not None
if is_complete:
perf_monitor.end_generation()
usage.update(perf_metrics)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,
error_code=0,
usage=usage,
is_reasoning_model=is_reasoning_model,
)
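For context, the model worker drives this function with a plain `params` dict; a rough sketch of calling it directly (values are illustrative, and `device`/`context_len` are accepted but not used by the body above):
```python
from mlx_lm import load

from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")
params = {
    # In the real flow the adapter builds this prompt via apply_chat_template.
    "prompt": "Hello, who are you?",
    "temperature": 0.7,
    "top_p": 0.95,
    "max_new_tokens": 128,
    "is_reasoning_model": True,  # parse out reasoning wrapped in think tags, if emitted
}
for output in generate_stream(model, tokenizer, params, device="mps", context_len=8192):
    # Each yielded ModelOutput carries the accumulated text plus token usage.
    print(output)
```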

uv.lock: 8821 changed lines. File diff suppressed because one or more lines are too long.