feat(model): Support MLX inference (#2781)
parent 9084c6c19c
commit d9d4d4b6bc
configs/dbgpt-local-mlx.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
[system]
# Load language from environment variable (it is set by the hook)
language = "${env:DBGPT_LANG:-zh}"
api_keys = []
encrypt_key = "your_secret_key"

# Server Configurations
[service.web]
host = "0.0.0.0"
port = 5670

[service.web.database]
type = "sqlite"
path = "pilot/meta_data/dbgpt.db"

[rag.storage]
[rag.storage.vector]
type = "chroma"
persist_path = "pilot/data"

# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"

[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
docs/docs/installation/advanced_usage/mlx_inference.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# MLX Inference

DB-GPT supports [MLX](https://github.com/ml-explore/mlx-lm) inference, a fast and easy-to-use LLM inference and serving library.

## Install dependencies

`MLX` is an optional dependency in DB-GPT. You can install it by adding the extra `--extra "mlx"` when installing dependencies.

```bash
# Use uv to install the dependencies needed for MLX
# Install core dependencies and select the desired extensions
uv sync --all-packages \
--extra "base" \
--extra "hf" \
--extra "mlx" \
--extra "rag" \
--extra "storage_chromadb" \
--extra "quant_bnb" \
--extra "dbgpts"
```

## Modify the configuration file

After installing the dependencies, you can modify your configuration file to use the `mlx` provider.

```toml
# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"
```

## Run the model

You can run the model with the following command:

```bash
uv run dbgpt start webserver --config {your_config_file}
```
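Before pointing DB-GPT at the model, it can be worth confirming that `mlx-lm` itself can load and run it. A minimal standalone sketch, assuming `mlx-lm` is installed and the model above is reachable on the Hugging Face hub (or replace the repo id with a local path):

```python
# Quick sanity check for mlx-lm, independent of DB-GPT (illustrative sketch).
from mlx_lm import load, stream_generate

# Downloads the model from the Hugging Face hub on first use;
# a local directory path works here as well.
model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")

# Stream a short completion token by token.
for res in stream_generate(model, tokenizer, prompt="Hello, MLX!", max_tokens=64):
    print(res.text, end="", flush=True)
print()
```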
@@ -170,6 +170,10 @@ const sidebars = {
        type: 'doc',
        id: 'installation/advanced_usage/vLLM_inference',
      },
      {
        type: 'doc',
        id: 'installation/advanced_usage/mlx_inference',
      },
      {
        type: 'doc',
        id: 'installation/advanced_usage/Llamacpp_server',
docs/yarn.lock (542 lines changed; diff suppressed because it is too large)
@@ -68,6 +68,9 @@ vllm = [
    # Just support GPU version on Linux
    "vllm>=0.7.0; sys_platform == 'linux'",
]
mlx = [
    "mlx-lm>=0.25.2; sys_platform == 'darwin'",
]
# vllm_pascal = [
#     # https://github.com/sasha0552/pascal-pkgs-ci
#     "vllm-pascal==0.7.2; sys_platform == 'linux'"
@@ -279,7 +279,9 @@ def load_config(config_file: str = None) -> ApplicationConfig:
     from dbgpt.configs.model_config import ROOT_PATH as DBGPT_ROOT_PATH

     if config_file is None:
-        config_file = os.path.join(DBGPT_ROOT_PATH, "configs", "dbgpt-siliconflow.toml")
+        config_file = os.path.join(
+            DBGPT_ROOT_PATH, "configs", "dbgpt-proxy-siliconflow.toml"
+        )
     elif not os.path.isabs(config_file):
         # If config_file is a relative path, make it relative to DBGPT_ROOT_PATH
         config_file = os.path.join(DBGPT_ROOT_PATH, config_file)
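The surrounding path handling reduces to a small rule: `None` falls back to the bundled proxy config, a relative path is joined onto `DBGPT_ROOT_PATH`, and an absolute path is used as-is. A standalone restatement of that rule (the function and root path below are hypothetical stand-ins, not the DB-GPT API):

```python
import os
from typing import Optional

DBGPT_ROOT_PATH = "/opt/db-gpt"  # hypothetical stand-in for dbgpt.configs.model_config.ROOT_PATH


def resolve_config_path(config_file: Optional[str]) -> str:
    """Illustrative mirror of the resolution logic in load_config()."""
    if config_file is None:
        # Default configuration shipped with DB-GPT.
        return os.path.join(DBGPT_ROOT_PATH, "configs", "dbgpt-proxy-siliconflow.toml")
    if not os.path.isabs(config_file):
        # Relative paths are interpreted relative to the DB-GPT root.
        return os.path.join(DBGPT_ROOT_PATH, config_file)
    return config_file


print(resolve_config_path("configs/dbgpt-local-mlx.toml"))
# -> /opt/db-gpt/configs/dbgpt-local-mlx.toml
```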
@@ -37,6 +37,7 @@ def scan_model_providers():
        base_class=LLMDeployModelParameters,
        specific_files=[
            "vllm_adapter",
            "mlx_adapter",
            "hf_adapter",
            "llama_cpp_adapter",
            "llama_cpp_py_adapter",
@@ -270,14 +270,18 @@ class LLMModelAdapter(ABC):
         ):
             return True
-        return (
-            lower_model_name_or_path
-            and "deepseek" in lower_model_name_or_path
-            and (
-                "r1" in lower_model_name_or_path
-                or "reasoning" in lower_model_name_or_path
-                or "reasoner" in lower_model_name_or_path
-            )
-        ) or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+        return (
+            (
+                lower_model_name_or_path
+                and "deepseek" in lower_model_name_or_path
+                and (
+                    "r1" in lower_model_name_or_path
+                    or "reasoning" in lower_model_name_or_path
+                    or "reasoner" in lower_model_name_or_path
+                )
+            )
+            or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwen3" in lower_model_name_or_path)
+        )

     def support_async(self) -> bool:
         """Whether the loaded model supports asynchronous calls"""
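The effect of this change is that Qwen3 model names are now flagged as reasoning models alongside DeepSeek-R1, DeepSeek reasoner variants, and QwQ. A quick illustrative check of the final name predicate (a standalone restatement, not the DB-GPT API):

```python
def _is_reasoning_name(name: str) -> bool:
    """Standalone restatement of the name check above (illustrative only)."""
    n = name.lower()
    return (
        ("deepseek" in n and ("r1" in n or "reasoning" in n or "reasoner" in n))
        or "qwq" in n
        or "qwen3" in n
    )


assert _is_reasoning_name("deepseek-ai/DeepSeek-R1")
assert _is_reasoning_name("Qwen/QwQ-32B")
assert _is_reasoning_name("Qwen/Qwen3-0.6B-MLX-4bit")  # newly matched by this commit
assert not _is_reasoning_name("Qwen/Qwen2.5-7B-Instruct")
```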
packages/dbgpt-core/src/dbgpt/model/adapter/mlx_adapter.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)


@dataclass
class MLXDeployModelParameters(LLMDeployModelParameters):
    """MLX deploy model parameters."""

    provider: str = "mlx"

    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _("The path of the model, if you want to deploy a local model."),
        },
    )
    device: Optional[str] = field(
        default="auto",
        metadata={
            "order": -700,
            "help": _(
                "Device to run model. If None, the device is automatically determined"
            ),
        },
    )

    concurrency: Optional[int] = field(
        default=100, metadata={"help": _("Model concurrency limit")}
    )

    @property
    def real_model_path(self) -> Optional[str]:
        """Get the real model path.

        If deploy model is not local, return None.
        """
        return self._resolve_root_path(self.path)

    @property
    def real_device(self) -> Optional[str]:
        """Get the real device."""
        return self.device or super().real_device


class MLXModelAdapter(LLMModelAdapter):
    def match(
        self,
        provider: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        return provider == ModelType.MLX

    def model_type(self) -> str:
        return ModelType.MLX

    def model_param_class(
        self, model_type: str = None
    ) -> Type[MLXDeployModelParameters]:
        return MLXDeployModelParameters

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        return None

    def load_from_params(self, params: MLXDeployModelParameters):
        """Load model from parameters."""
        try:
            from mlx_lm import load
        except ImportError:
            logger.error(
                "MLX model adapter requires mlx_lm package. "
                "Please install it with `pip install mlx-lm`."
            )
            raise

        model_path = params.real_model_path
        model, tokenizer = load(model_path)
        return model, tokenizer

    def support_generate_function(self) -> bool:
        return False

    def get_generate_stream_function(
        self, model, deploy_model_params: LLMDeployModelParameters
    ):
        from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

        return generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        if not tokenizer:
            raise ValueError("tokenizer is None")
        if hasattr(tokenizer, "apply_chat_template"):
            messages = self.transform_model_messages(
                messages, convert_to_compatible_format
            )
            logger.debug(f"The messages after transform: \n{messages}")
            str_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return str_prompt
        return None


register_model_adapter(MLXModelAdapter, supported_models=COMMON_HF_MODELS)
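The adapter delegates prompt construction to the tokenizer's chat template via `apply_chat_template`. A rough sketch of that step with a raw `mlx-lm` tokenizer, outside of DB-GPT, using a plain OpenAI-style message list instead of `ModelMessage` objects:

```python
from mlx_lm import load

# Assumes the model from the example configuration above.
model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is MLX?"},
]

# Same call the adapter makes in get_str_prompt(): render the chat history
# into a single prompt string that ends with the assistant turn.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(prompt)
```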
@@ -17,6 +17,7 @@ class ModelType:
    LLAMA_CPP_SERVER = "llama.cpp.server"
    PROXY = "proxy"
    VLLM = "vllm"
    MLX = "mlx"
    # TODO, support more model type
packages/dbgpt-core/src/dbgpt/model/llm/llm_out/mlx_llm.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from typing import Dict

import mlx.nn as nn
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

from dbgpt.core import ModelOutput

from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,
    parse_chat_message,
)


def generate_stream(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    params: Dict,
    device: str,
    context_len: int,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0))
    top_p = float(params.get("top_p", 1.0))
    top_k = params.get("top_k", 0)
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    # echo = bool(params.get("echo", True))
    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
    think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
    is_reasoning_model = params.get("is_reasoning_model", False)

    reasoning_patterns = [
        {"start": think_start_token, "end": think_end_token},
    ]
    sampler = make_sampler(
        temp=temperature,
        top_p=top_p,
        # min_p=min_p,
        # min_tokens_to_keep=min_tokens_to_keep,
        top_k=top_k,
        xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
    )

    # Initialize the performance monitor with estimated token count
    estimated_input_tokens = len(tokenizer.encode(prompt))
    perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)

    # Start measuring prefill phase
    perf_monitor.start_prefill()

    results_generator = stream_generate(
        model, tokenizer, prompt=prompt, max_tokens=max_new_tokens, sampler=sampler
    )
    text = ""
    is_first = True
    for res in results_generator:
        new_text = res.text
        text += new_text
        # The prompt processing tokens-per-second.
        # prompt_tps = res.prompt_tps
        # The number of tokens in the prompt.
        prompt_tokens = res.prompt_tokens
        # The number of generated tokens.
        generation_tokens = res.generation_tokens
        # The tokens-per-second for generation.
        # generation_tps = res.generation_tps
        # The peak memory used so far in GB.
        # peak_memory = res.peak_memory
        # "length", "stop" or `None`
        finish_reason = res.finish_reason
        if (
            prompt.rstrip().endswith(think_start_token)
            and is_reasoning_model
            and is_first
        ):
            text = think_start_token + "\n" + text
            is_first = False

        msg = parse_chat_message(
            text,
            extract_reasoning=is_reasoning_model,
            reasoning_patterns=reasoning_patterns,
        )

        # If this is the first iteration, update the input token count
        if perf_monitor.metrics.input_token_count != prompt_tokens:
            perf_monitor.metrics.input_token_count = prompt_tokens
        # Update performance metrics based on current token count
        perf_metrics = perf_monitor.on_tokens_received(generation_tokens)
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generation_tokens,
            "total_tokens": prompt_tokens + generation_tokens,
        }
        # Check if generation is complete
        is_complete = finish_reason is not None
        if is_complete:
            perf_monitor.end_generation()
        usage.update(perf_metrics)

        yield ModelOutput.build(
            msg.content,
            msg.reasoning_content,
            error_code=0,
            usage=usage,
            is_reasoning_model=is_reasoning_model,
        )
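As for the calling convention, the model worker hands `generate_stream` a flat `params` dict and consumes the yielded `ModelOutput` chunks. A rough standalone sketch, assuming the model above is available locally or on the hub (the `device` and `context_len` arguments are accepted but not used by this implementation):

```python
from mlx_lm import load

from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")

params = {
    "prompt": "Briefly explain what Apple's MLX framework is.",
    "temperature": 0.7,
    "top_p": 0.9,
    "max_new_tokens": 256,
    "is_reasoning_model": True,  # enables reasoning-block extraction via parse_chat_message
}

# Each yielded ModelOutput carries the accumulated content plus token usage so far.
last = None
for output in generate_stream(model, tokenizer, params, device="auto", context_len=4096):
    last = output
print(last)
```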