feat(model): Support MLX inference (#2781)
@@ -37,6 +37,7 @@ def scan_model_providers():
         base_class=LLMDeployModelParameters,
         specific_files=[
             "vllm_adapter",
+            "mlx_adapter",
             "hf_adapter",
             "llama_cpp_adapter",
             "llama_cpp_py_adapter",
@@ -270,14 +270,18 @@ class LLMModelAdapter(ABC):
         ):
             return True
-        return (
-            lower_model_name_or_path
-            and "deepseek" in lower_model_name_or_path
-            and (
-                "r1" in lower_model_name_or_path
-                or "reasoning" in lower_model_name_or_path
-                or "reasoner" in lower_model_name_or_path
-            )
-        ) or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+        return (
+            (
+                lower_model_name_or_path
+                and "deepseek" in lower_model_name_or_path
+                and (
+                    "r1" in lower_model_name_or_path
+                    or "reasoning" in lower_model_name_or_path
+                    or "reasoner" in lower_model_name_or_path
+                )
+            )
+            or (lower_model_name_or_path and "qwq" in lower_model_name_or_path)
+            or (lower_model_name_or_path and "qwen3" in lower_model_name_or_path)
+        )
 
     def support_async(self) -> bool:
         """Whether the loaded model supports asynchronous calls"""
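Note: the hunk above only regroups the existing DeepSeek-R1/QwQ check and adds a "qwen3" match, so Qwen3 models are treated as reasoning models by default. A minimal standalone sketch of the same name-based check (the helper name and test strings below are illustrative, not part of the commit):

    def _looks_like_reasoning_model(name: str) -> bool:
        # Mirrors the boolean expression in LLMModelAdapter.is_reasoning_model:
        # DeepSeek R1/reasoning/reasoner variants, QwQ, and (new here) Qwen3.
        lower = name.lower()
        return (
            (
                "deepseek" in lower
                and ("r1" in lower or "reasoning" in lower or "reasoner" in lower)
            )
            or "qwq" in lower
            or "qwen3" in lower
        )

    assert _looks_like_reasoning_model("DeepSeek-R1-Distill-Qwen-7B")
    assert _looks_like_reasoning_model("Qwen3-4B")  # newly matched by this change
    assert not _looks_like_reasoning_model("Qwen2.5-7B-Instruct")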
packages/dbgpt-core/src/dbgpt/model/adapter/mlx_adapter.py (new file, 126 lines)
@@ -0,0 +1,126 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type

from dbgpt.core import ModelMessage
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter, register_model_adapter
from dbgpt.model.adapter.model_metadata import COMMON_HF_MODELS
from dbgpt.model.adapter.template import ConversationAdapter
from dbgpt.model.base import ModelType
from dbgpt.util.i18n_utils import _

logger = logging.getLogger(__name__)


@dataclass
class MLXDeployModelParameters(LLMDeployModelParameters):
    """Local deploy model parameters."""

    provider: str = "mlx"

    path: Optional[str] = field(
        default=None,
        metadata={
            "order": -800,
            "help": _("The path of the model, if you want to deploy a local model."),
        },
    )
    device: Optional[str] = field(
        default="auto",
        metadata={
            "order": -700,
            "help": _(
                "Device to run model. If None, the device is automatically determined"
            ),
        },
    )

    concurrency: Optional[int] = field(
        default=100, metadata={"help": _("Model concurrency limit")}
    )

    @property
    def real_model_path(self) -> Optional[str]:
        """Get the real model path.

        If deploy model is not local, return None.
        """
        return self._resolve_root_path(self.path)

    @property
    def real_device(self) -> Optional[str]:
        """Get the real device."""
        return self.device or super().real_device


class MLXModelAdapter(LLMModelAdapter):
    def match(
        self,
        provider: str,
        model_name: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> bool:
        return provider == ModelType.MLX

    def model_type(self) -> str:
        return ModelType.MLX

    def model_param_class(
        self, model_type: str = None
    ) -> Type[MLXDeployModelParameters]:
        return MLXDeployModelParameters

    def get_default_conv_template(
        self, model_name: str, model_path: str
    ) -> Optional[ConversationAdapter]:
        return None

    def load_from_params(self, params: MLXDeployModelParameters):
        """Load model from parameters."""
        try:
            from mlx_lm import load
        except ImportError:
            logger.error(
                "MLX model adapter requires mlx_lm package. "
                "Please install it with `pip install mlx-lm`."
            )
            raise

        model_path = params.real_model_path
        model, tokenizer = load(model_path)
        return model, tokenizer

    def support_generate_function(self) -> bool:
        return False

    def get_generate_stream_function(
        self, model, deploy_model_params: LLMDeployModelParameters
    ):
        from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

        return generate_stream
    def get_str_prompt(
        self,
        params: Dict,
        messages: List[ModelMessage],
        tokenizer: Any,
        prompt_template: str = None,
        convert_to_compatible_format: bool = False,
    ) -> Optional[str]:
        if not tokenizer:
            raise ValueError("tokenizer is None")
        if hasattr(tokenizer, "apply_chat_template"):
            messages = self.transform_model_messages(
                messages, convert_to_compatible_format
            )
            logger.debug(f"The messages after transform: \n{messages}")
            str_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            return str_prompt
        return None


register_model_adapter(MLXModelAdapter, supported_models=COMMON_HF_MODELS)
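Outside the worker process, the adapter can be exercised roughly as follows. This is a sketch only: it assumes the base `LLMDeployModelParameters` accepts a `name` argument, the model path is a placeholder, and `mlx-lm` is installed on Apple Silicon.

    from dbgpt.core import ModelMessage

    # Placeholder name/path for illustration only.
    params = MLXDeployModelParameters(
        name="qwen3-mlx",
        path="/path/to/Qwen3-4B-4bit-mlx",
    )
    adapter = MLXModelAdapter()
    model, tokenizer = adapter.load_from_params(params)

    messages = [ModelMessage(role="human", content="Hello from MLX!")]
    # Renders the messages through the tokenizer's chat template.
    prompt = adapter.get_str_prompt({}, messages, tokenizer)
    print(prompt)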
@@ -17,6 +17,7 @@ class ModelType:
     LLAMA_CPP_SERVER = "llama.cpp.server"
     PROXY = "proxy"
     VLLM = "vllm"
+    MLX = "mlx"
     # TODO, support more model type
packages/dbgpt-core/src/dbgpt/model/llm/llm_out/mlx_llm.py (new file, 110 lines)
@@ -0,0 +1,110 @@
from typing import Dict

import mlx.nn as nn
from mlx_lm import stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.tokenizer_utils import TokenizerWrapper

from dbgpt.core import ModelOutput

from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,
    parse_chat_message,
)


def generate_stream(
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    params: Dict,
    device: str,
    context_len: int,
):
    prompt = params["prompt"]
    temperature = float(params.get("temperature", 0))
    top_p = float(params.get("top_p", 1.0))
    top_k = params.get("top_k", 0)
    max_new_tokens = int(params.get("max_new_tokens", 2048))
    # echo = bool(params.get("echo", True))
    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
    think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
    is_reasoning_model = params.get("is_reasoning_model", False)

    reasoning_patterns = [
        {"start": think_start_token, "end": think_end_token},
    ]
    sampler = make_sampler(
        temp=temperature,
        top_p=top_p,
        # min_p=min_p,
        # min_tokens_to_keep=min_tokens_to_keep,
        top_k=top_k,
        xtc_special_tokens=tokenizer.encode("\n") + list(tokenizer.eos_token_ids),
    )

    # Initialize the performance monitor with estimated token count
    estimated_input_tokens = len(tokenizer.encode(prompt))
    perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)

    # Start measuring prefill phase
    perf_monitor.start_prefill()

    results_generator = stream_generate(
        model, tokenizer, prompt=prompt, max_tokens=max_new_tokens, sampler=sampler
    )
    text = ""
    is_first = True
    for res in results_generator:
        new_text = res.text
        text += new_text
        # The prompt processing tokens-per-second.
        # prompt_tps = res.prompt_tps
        # The number of tokens in the prompt.
        prompt_tokens = res.prompt_tokens
        # The number of generated tokens.
        generation_tokens = res.generation_tokens
        # The tokens-per-second for generation.
        # generation_tps = res.generation_tps
        # The peak memory used so far in GB.
        # peak_memory = res.peak_memory
        # "length", "stop" or `None`
        finish_reason = res.finish_reason
        if (
            prompt.rstrip().endswith(think_start_token)
            and is_reasoning_model
            and is_first
        ):
            text = think_start_token + "\n" + text
            is_first = False

        msg = parse_chat_message(
            text,
            extract_reasoning=is_reasoning_model,
            reasoning_patterns=reasoning_patterns,
        )

        # If this is the first iteration, update the input token count
        if perf_monitor.metrics.input_token_count != prompt_tokens:
            perf_monitor.metrics.input_token_count = prompt_tokens

        # Update performance metrics based on current token count
        perf_metrics = perf_monitor.on_tokens_received(generation_tokens)
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": generation_tokens,
            "total_tokens": prompt_tokens + generation_tokens,
        }
        # Check if generation is complete
        is_complete = finish_reason is not None
        if is_complete:
            perf_monitor.end_generation()
        usage.update(perf_metrics)

        yield ModelOutput.build(
            msg.content,
            msg.reasoning_content,
            error_code=0,
            usage=usage,
            is_reasoning_model=is_reasoning_model,
        )
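The worker drives this generator with a params dict built from the request; each yielded `ModelOutput` carries the cumulative parsed text plus token usage, and the final chunk includes the end-of-generation metrics. A minimal driver sketch (the model path and prompt are placeholders, and the `device`/`context_len` arguments are accepted but not used in the body above):

    from mlx_lm import load

    from dbgpt.model.llm.llm_out.mlx_llm import generate_stream

    model, tokenizer = load("/path/to/Qwen3-4B-4bit-mlx")  # placeholder model path
    params = {
        "prompt": "Briefly explain what MLX is.",
        "temperature": 0.7,
        "max_new_tokens": 256,
        "is_reasoning_model": False,
    }

    last = None
    for output in generate_stream(model, tokenizer, params, "mps", 8192):
        last = output  # each chunk is a ModelOutput with text and usage so far

    print(last)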