Mirror of https://github.com/csunny/DB-GPT.git, synced 2025-07-23 12:21:08 +00:00

feat(model): Support MLX inference (#2781)

parent 9084c6c19c
commit d9d4d4b6bc

configs/dbgpt-local-mlx.toml (new file, 36 lines)
@@ -0,0 +1,36 @@
[system]
# Load language from environment variable(It is set by the hook)
language = "${env:DBGPT_LANG:-zh}"
api_keys = []
encrypt_key = "your_secret_key"

# Server Configurations
[service.web]
host = "0.0.0.0"
port = 5670

[service.web.database]
type = "sqlite"
path = "pilot/meta_data/dbgpt.db"

[rag.storage]
[rag.storage.vector]
type = "chroma"
persist_path = "pilot/data"

# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"

[[models.embeddings]]
name = "BAAI/bge-large-zh-v1.5"
provider = "hf"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# path = "the-model-path-in-the-local-file-system"
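To try this config directly, the launch command documented in the new MLX inference doc below can point at it; a minimal sketch, assuming you run it from the repository root so the relative paths resolve against DBGPT_ROOT_PATH:

```bash
# Start the DB-GPT webserver with the shipped MLX example config.
# The path is relative, so load_config resolves it against the project root.
uv run dbgpt start webserver --config configs/dbgpt-local-mlx.toml
```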
docs/docs/installation/advanced_usage/mlx_inference.md (new file, 43 lines)
@@ -0,0 +1,43 @@
# MLX Inference

DB-GPT supports [MLX](https://github.com/ml-explore/mlx-lm) inference, a fast and easy-to-use LLM inference and service library.

## Install dependencies

`MLX` is an optional dependency in DB-GPT. You can install it by adding the extra `--extra "mlx"` when installing dependencies.

```bash
# Use uv to install dependencies needed for mlx
# Install core dependencies and select desired extensions
uv sync --all-packages \
    --extra "base" \
    --extra "hf" \
    --extra "mlx" \
    --extra "rag" \
    --extra "storage_chromadb" \
    --extra "quant_bnb" \
    --extra "dbgpts"
```

## Modify configuration file

After installing the dependencies, you can modify your configuration file to use the `mlx` provider.

```toml
# Model Configurations
[models]
[[models.llms]]
name = "Qwen/Qwen3-0.6B-MLX-4bit"
provider = "mlx"
# If not provided, the model will be downloaded from the Hugging Face model hub
# uncomment the following line to specify the model path in the local file system
# https://huggingface.co/Qwen/Qwen3-0.6B-MLX-4bit
# path = "the-model-path-in-the-local-file-system"
```

## Run the Model

You can run the model using the following command:

```bash
uv run dbgpt start webserver --config {your_config_file}
```
@@ -170,6 +170,10 @@ const sidebars = {
       {
         type: 'doc',
         id: 'installation/advanced_usage/vLLM_inference',
       },
+      {
+        type: 'doc',
+        id: 'installation/advanced_usage/mlx_inference',
+      },
       {
         type: 'doc',
         id: 'installation/advanced_usage/Llamacpp_server',
docs/yarn.lock (542 lines changed; diff suppressed because it is too large)
@@ -68,6 +68,9 @@ vllm = [
     # Just support GPU version on Linux
     "vllm>=0.7.0; sys_platform == 'linux'",
 ]
+mlx = [
+    "mlx-lm>=0.25.2; sys_platform == 'darwin'",
+]
 # vllm_pascal = [
 #     # https://github.com/sasha0552/pascal-pkgs-ci
 #     "vllm-pascal==0.7.2; sys_platform == 'linux'"
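The `mlx` extra only resolves on macOS (`sys_platform == 'darwin'`). If you manage dependencies outside of `uv sync`, a minimal sketch of installing the runtime directly, mirroring the version pin above (the adapter below suggests the same when the import fails):

```bash
# Install the MLX LLM runtime directly; only meaningful on Apple Silicon macOS.
pip install "mlx-lm>=0.25.2"
```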
@@ -279,7 +279,9 @@ def load_config(config_file: str = None) -> ApplicationConfig:
     from dbgpt.configs.model_config import ROOT_PATH as DBGPT_ROOT_PATH

     if config_file is None:
-        config_file = os.path.join(DBGPT_ROOT_PATH, "configs", "dbgpt-siliconflow.toml")
+        config_file = os.path.join(
+            DBGPT_ROOT_PATH, "configs", "dbgpt-proxy-siliconflow.toml"
+        )
     elif not os.path.isabs(config_file):
         # If config_file is a relative path, make it relative to DBGPT_ROOT_PATH
         config_file = os.path.join(DBGPT_ROOT_PATH, config_file)
@@ -37,6 +37,7 @@ def scan_model_providers():
         base_class=LLMDeployModelParameters,
         specific_files=[
             "vllm_adapter",
+            "mlx_adapter",
             "hf_adapter",
             "llama_cpp_adapter",
             "llama_cpp_py_adapter",
@@ -270,14 +270,18 @@ class LLMModelAdapter(ABC):
         ):
             return True
         return (
-            lower_model_name_or_path
-            and "deepseek" in lower_model_name_or_path
-            and (
-                "r1" in lower_model_name_or_path
-                or "reasoning" in lower_model_name_or_path
-                or "reasoner" in lower_model_name_or_path
+            (
+                lower_model_name_or_path
+                and "deepseek" in lower_model_name_or_path
+                and (
+                    "r1" in lower_model_name_or_path
+                    or "reasoning" in lower_model_name_or_path
+                    or "reasoner" in lower_model_name_or_path
+                )
             )
-        ) or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwen3" in lower_model_name_or_path)
+        )

     def support_async(self) -> bool:
         """Whether the loaded model supports asynchronous calls"""
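For reference, a standalone sketch of the name-based fallback this hunk produces (the function name and example model IDs are illustrative, not part of the adapter API):

```python
def _looks_like_reasoning_model(name: str) -> bool:
    # Mirrors the updated check: DeepSeek R1/reasoner variants, QwQ, and now Qwen3.
    n = name.lower()
    return (
        ("deepseek" in n and ("r1" in n or "reasoning" in n or "reasoner" in n))
        or "qwq" in n
        or "qwen3" in n
    )


assert _looks_like_reasoning_model("deepseek-ai/DeepSeek-R1")
assert _looks_like_reasoning_model("Qwen/QwQ-32B")
assert _looks_like_reasoning_model("Qwen/Qwen3-0.6B-MLX-4bit")  # newly matched by this commit
```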
packages/dbgpt-core/src/dbgpt/model/adapter/mlx_adapter.py (new file, 126 lines)

@@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)


@dataclass
class MLXDeployModelParameters(LLMDeployModelParameters):
    """Local deploy model parameters."""

    provider: str = "mlx"

    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _("The path of the model, if you want to deploy a local model."),
        },
    )
    device: Optional[str] = field(
        default="auto",
        metadata={
            "order": -700,
            "help": _(
                "Device to run model. If None, the device is automatically determined"
            ),
        },
    )

    concurrency: Optional[int] = field(
        default=100, metadata={"help": _("Model concurrency limit")}
    )

    @property
    def real_model_path(self) -> Optional[str]:
        """Get the real model path.

        If deploy model is not local, return None.
        """
        return self._resolve_root_path(self.path)

    @property
    def real_device(self) -> Optional[str]:
        """Get the real device."""
        return self.device or super().real_device


class MLXModelAdapter(LLMModelAdapter):
    def match(
        self,
        provider: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        return provider == ModelType.MLX

    def model_type(self) -> str:
        return ModelType.MLX

    def model_param_class(
        self, model_type: str = None
    ) -> Type[MLXDeployModelParameters]:
        return MLXDeployModelParameters

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        return None

    def load_from_params(self, params: MLXDeployModelParameters):
        """Load model from parameters."""
        try:
            from mlx_lm import load
        except ImportError:
            logger.error(
                "MLX model adapter requires mlx_lm package. "
                "Please install it with `pip install mlx-lm`."
            )
            raise

        model_path = params.real_model_path
        model, tokenizer = load(model_path)
        return model, tokenizer

    def support_generate_function(self) -> bool:
        return False

    def get_generate_stream_function(
        self, model, deploy_model_params: LLMDeployModelParameters
    ):
        from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

        return generate_stream

    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        if not tokenizer:
            raise ValueError("tokenizer is None")
        if hasattr(tokenizer, "apply_chat_template"):
            messages = self.transform_model_messages(
                messages, convert_to_compatible_format
            )
            logger.debug(f"The messages after transform: \n{messages}")
            str_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return str_prompt
        return None


register_model_adapter(MLXModelAdapter, supported_models=COMMON_HF_MODELS)
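A hypothetical snippet showing how the adapter's pieces fit together outside the DB-GPT model worker; it assumes the base `LLMDeployModelParameters` accepts a `name` field (inside DB-GPT these objects are built from the TOML config instead):

```python
from dbgpt.model.adapter.mlx_adapter import MLXDeployModelParameters, MLXModelAdapter

params = MLXDeployModelParameters(
    name="Qwen/Qwen3-0.6B-MLX-4bit",  # downloaded from Hugging Face when no local path is set
    provider="mlx",
)
adapter = MLXModelAdapter()
model, tokenizer = adapter.load_from_params(params)  # wraps mlx_lm.load()

# The worker normally drives this generator; params must at least carry the prompt.
generate = adapter.get_generate_stream_function(model, params)
gen_params = {"prompt": "Hello, MLX!", "temperature": 0.7, "max_new_tokens": 64}
for output in generate(model, tokenizer, gen_params, device="mps", context_len=4096):
    print(output)
```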
@@ -17,6 +17,7 @@ class ModelType:
     LLAMA_CPP_SERVER = "llama.cpp.server"
     PROXY = "proxy"
     VLLM = "vllm"
+    MLX = "mlx"
     # TODO, support more model type
packages/dbgpt-core/src/dbgpt/model/llm/llm_out/mlx_llm.py (new file, 110 lines)

@@ -0,0 +1,110 @@
from typing import Dict

import mlx.nn as nn
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

from dbgpt.core import ModelOutput

from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,
    parse_chat_message,
)


def generate_stream(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    params: Dict,
    device: str,
    context_len: int,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0))
    top_p = float(params.get("top_p", 1.0))
    top_k = params.get("top_k", 0)
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    # echo = bool(params.get("echo", True))
    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
    think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
    is_reasoning_model = params.get("is_reasoning_model", False)

    reasoning_patterns = [
        {"start": think_start_token, "end": think_end_token},
    ]
    sampler = make_sampler(
        temp=temperature,
        top_p=top_p,
        # min_p=min_p,
        # min_tokens_to_keep=min_tokens_to_keep,
        top_k=top_k,
        xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
    )

    # Initialize the performance monitor with estimated token count
    estimated_input_tokens = len(tokenizer.encode(prompt))
    perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)

    # Start measuring prefill phase
    perf_monitor.start_prefill()

    results_generator = stream_generate(
        model, tokenizer, prompt=prompt, max_tokens=max_new_tokens, sampler=sampler
    )
    text = ""
    is_first = True
    for res in results_generator:
        new_text = res.text
        text += new_text
        # The prompt processing tokens-per-second.
        # prompt_tps = res.prompt_tps
        # The number of tokens in the prompt.
        prompt_tokens = res.prompt_tokens
        # The number of generated tokens.
        generation_tokens = res.generation_tokens
        # The tokens-per-second for generation.
        # generation_tps = res.generation_tps
        # The peak memory used so far in GB.
        # peak_memory = res.peak_memory
        # "length", "stop" or `None`
        finish_reason = res.finish_reason
        if (
            prompt.rstrip().endswith(think_start_token)
            and is_reasoning_model
            and is_first
        ):
            text = think_start_token + "\n" + text
            is_first = False

        msg = parse_chat_message(
            text,
            extract_reasoning=is_reasoning_model,
            reasoning_patterns=reasoning_patterns,
        )

        # If this is the first iteration, update the input token count
        if perf_monitor.metrics.input_token_count != prompt_tokens:
            perf_monitor.metrics.input_token_count = prompt_tokens
        # Update performance metrics based on current token count
        perf_metrics = perf_monitor.on_tokens_received(generation_tokens)
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generation_tokens,
            "total_tokens": prompt_tokens + generation_tokens,
        }
        # Check if generation is complete
        is_complete = finish_reason is not None
        if is_complete:
            perf_monitor.end_generation()
        usage.update(perf_metrics)

        yield ModelOutput.build(
            msg.content,
            msg.reasoning_content,
            error_code=0,
            usage=usage,
            is_reasoning_model=is_reasoning_model,
        )
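For context, a minimal sketch of the mlx_lm primitives this module builds on, using only the calls that already appear above (the model name and prompt are placeholders):

```python
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

# load() fetches the model from Hugging Face if it is not a local path.
model, tokenizer = load("Qwen/Qwen3-0.6B-MLX-4bit")
sampler = make_sampler(temp=0.7, top_p=0.9, top_k=0)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "What is MLX?"}],
    tokenize=False,
    add_generation_prompt=True,
)
# Each response exposes text, prompt_tokens, generation_tokens and finish_reason,
# which generate_stream() above folds into ModelOutput usage metrics.
for res in stream_generate(model, tokenizer, prompt=prompt, max_tokens=256, sampler=sampler):
    print(res.text, end="", flush=True)
```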