mirror of https://github.com/csunny/DB-GPT.git
synced 2025-07-24 12:45:45 +00:00

feat(model): Add patch for Qwen3 moe (#2676)

This commit is contained in:
parent 3a65e1b65f
commit 4f39850ac1
@@ -65,7 +65,14 @@ class ModelInferenceMetrics:
    """The total number of tokens (prompt plus completion)."""

    speed_per_second: Optional[float] = None
-    """The average number of tokens generated per second."""
+    """The average number of tokens generated per second. Includes both prefill and
+    decode time."""

    prefill_tokens_per_second: Optional[float] = None
    """Prefill speed in tokens per second."""

    decode_tokens_per_second: Optional[float] = None
    """The average number of tokens generated per second during the decode phase."""

    current_gpu_infos: Optional[List[GPUInfo]] = None
    """Current gpu information, all devices"""

@@ -97,6 +104,12 @@ class ModelInferenceMetrics:
        completion_tokens = last_metrics.completion_tokens if last_metrics else None
        total_tokens = last_metrics.total_tokens if last_metrics else None
        speed_per_second = last_metrics.speed_per_second if last_metrics else None
        prefill_tokens_per_second = (
            last_metrics.prefill_tokens_per_second if last_metrics else None
        )
        decode_tokens_per_second = (
            last_metrics.decode_tokens_per_second if last_metrics else None
        )
        current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
        avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None

@@ -116,6 +129,8 @@ class ModelInferenceMetrics:
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            speed_per_second=speed_per_second,
            prefill_tokens_per_second=prefill_tokens_per_second,
            decode_tokens_per_second=decode_tokens_per_second,
            current_gpu_infos=current_gpu_infos,
            avg_gpu_infos=avg_gpu_infos,
        )

@@ -124,6 +139,65 @@ class ModelInferenceMetrics:
        """Convert the model inference metrics to dict."""
        return asdict(self)

    def to_printable_string(self) -> str:
        """Stringify the metrics in an elegant format.

        Returns:
            str: A formatted string containing first token latency, prefill speed,
                decode speed, prompt tokens and completion tokens.
        """
        lines = []

        # Calculate first token latency if possible
        first_token_latency = None
        if self.first_token_time_ms is not None and self.start_time_ms is not None:
            first_token_latency = (
                self.first_token_time_ms - self.start_time_ms
            ) / 1000.0

        # Add section header
        lines.append("=== Model Inference Metrics ===")

        # Latency metrics
        lines.append("\n▶ Latency:")
        if first_token_latency is not None:
            lines.append(f" • First Token Latency: {first_token_latency:.3f}s")
        else:
            lines.append(" • First Token Latency: N/A")

        # Speed metrics
        lines.append("\n▶ Speed:")
        if self.prefill_tokens_per_second is not None:
            lines.append(
                f" • Prefill Speed: {self.prefill_tokens_per_second:.2f} tokens/s"
            )
        else:
            lines.append(" • Prefill Speed: N/A")

        if self.decode_tokens_per_second is not None:
            lines.append(
                f" • Decode Speed: {self.decode_tokens_per_second:.2f} tokens/s"
            )
        else:
            lines.append(" • Decode Speed: N/A")

        # Token counts
        lines.append("\n▶ Tokens:")
        if self.prompt_tokens is not None:
            lines.append(f" • Prompt Tokens: {self.prompt_tokens}")
        else:
            lines.append(" • Prompt Tokens: N/A")

        if self.completion_tokens is not None:
            lines.append(f" • Completion Tokens: {self.completion_tokens}")
        else:
            lines.append(" • Completion Tokens: N/A")

        if self.total_tokens is not None:
            lines.append(f" • Total Tokens: {self.total_tokens}")

        return "\n".join(lines)


@dataclass
@PublicAPI(stability="beta")
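For context, the sketch below shows roughly what the new to_printable_string() output looks like. It is illustrative only: it assumes ModelInferenceMetrics can be constructed with keyword arguments for the fields shown above and that the class is importable from dbgpt.core (both assumptions, not verified against this commit).

# Hypothetical usage; import path and constructor defaults are assumptions.
from dbgpt.core import ModelInferenceMetrics  # assumed re-export

metrics = ModelInferenceMetrics(
    start_time_ms=1_000,
    first_token_time_ms=1_850,  # first token 0.85 s after start
    prompt_tokens=128,
    completion_tokens=256,
    total_tokens=384,
    prefill_tokens_per_second=150.6,
    decode_tokens_per_second=75.3,
)
print(metrics.to_printable_string())
# === Model Inference Metrics ===
#
# ▶ Latency:
#  • First Token Latency: 0.850s
#
# ▶ Speed:
#  • Prefill Speed: 150.60 tokens/s
#  • Decode Speed: 75.30 tokens/s
#
# ▶ Tokens:
#  • Prompt Tokens: 128
#  • Completion Tokens: 256
#  • Total Tokens: 384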
@@ -223,6 +223,12 @@ class LLMModelAdapter(ABC):
        """Load model and tokenizer"""
        raise NotImplementedError

    def model_patch(
        self, deploy_model_params: LLMDeployModelParameters
    ) -> Optional[Callable[[Any], Any]]:
        """Patch function for model"""
        return None

    def parse_max_length(self, model, tokenizer) -> Optional[int]:
        """Parse the max_length of the model.
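A sketch of how a concrete adapter can use the new hook; MyAdapter and optimize_model are hypothetical names, and the imports follow paths that appear elsewhere in this diff.

from typing import Any, Callable, Optional

from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter


def optimize_model(model: Any) -> Any:
    # Rewrite or wrap sub-modules here, then return the (possibly new) model object.
    return model


class MyAdapter(LLMModelAdapter):  # other adapter methods omitted in this sketch
    def model_patch(
        self, deploy_model_params: LLMDeployModelParameters
    ) -> Optional[Callable[[Any], Any]]:
        # Return a callable; the loader applies it after the model is loaded.
        return optimize_model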
@@ -617,6 +617,13 @@ class Qwen3Adapter(QwenAdapter):
                " transformers package."
            )

    def model_patch(self, deploy_model_params: LLMDeployModelParameters):
        """Apply the monkey patch to moe model for high inference speed."""
        from ..llm.monkey_patch import apply_qwen3_moe_monkey_patch

        return apply_qwen3_moe_monkey_patch

    def is_reasoning_model(
        self,
        deploy_model_params: LLMDeployModelParameters,
@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-

import logging
-from typing import Any, Dict, Optional, cast
+from typing import Any, Callable, Dict, Optional, cast

from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter

@@ -146,8 +146,16 @@ def huggingface_loader(
    if model_params.attn_implementation:
        kwargs["attn_implementation"] = model_params.attn_implementation

    model_patch = llm_adapter.model_patch(model_params)

    model, tokenizer = _hf_try_load_default_quantization_model(
-        model_path, llm_adapter, device, num_gpus, model_params, kwargs
+        model_path,
+        llm_adapter,
+        device,
+        num_gpus,
+        model_params,
+        kwargs,
+        model_patch=model_patch,
    )
    if model:
        return model, tokenizer

@@ -176,7 +184,7 @@
        compress_module(model, device)

    return _hf_handle_model_and_tokenizer(
-        model, tokenizer, device, num_gpus, model_params
+        model, tokenizer, device, num_gpus, model_params, model_patch=model_patch
    )


@@ -187,6 +195,7 @@ def _hf_try_load_default_quantization_model(
    num_gpus: int,
    model_params: HFLLMDeployModelParameters,
    kwargs: Dict[str, Any],
    model_patch: Optional[Callable[[Any], Any]] = None,
):
    """Try load default quantization model(Support by huggingface default)"""
    cloned_kwargs = {k: v for k, v in kwargs.items()}

@@ -216,7 +225,13 @@
        if model:
            logger.info(f"Load default quantization model {model_name} success")
            return _hf_handle_model_and_tokenizer(
-                model, tokenizer, device, num_gpus, model_params, to=False
+                model,
+                tokenizer,
+                device,
+                num_gpus,
+                model_params,
+                to=False,
+                model_patch=model_patch,
            )
        return None, None
    except Exception as e:

@@ -233,6 +248,7 @@ def _hf_handle_model_and_tokenizer(
    num_gpus: int,
    model_params: HFLLMDeployModelParameters,
    to: bool = True,
    model_patch: Optional[Callable[[Any], Any]] = None,
):
    if (device == "cuda" and num_gpus == 1) or device == "mps" and tokenizer:
        # TODO: Check cpu_offloading

@@ -243,6 +259,11 @@
            pass
        except AttributeError:
            pass
    try:
        if model_patch:
            model = model_patch(model)
    except Exception:
        pass
    if model_params.verbose:
        print(model)
    return model, tokenizer
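Condensed, the patch wiring added to the loader amounts to the sketch below (a simplification for readability, not the actual loader code):

def _apply_model_patch_sketch(llm_adapter, model_params, model):
    # The adapter decides whether a patch applies to this deployment.
    model_patch = llm_adapter.model_patch(model_params)  # may be None
    if model_patch:
        try:
            # e.g. apply_qwen3_moe_monkey_patch for Qwen3 MoE checkpoints
            model = model_patch(model)
        except Exception:
            # Patching is best-effort: fall back to the unpatched model.
            pass
    return model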
@@ -197,7 +197,8 @@ class DefaultModelWorker(ModelWorker):
            yield model_output
        logger.info(
            f"\n\nfull stream output:\n{previous_response}\n\nmodel "
-            f"generate_stream params:\n{params}"
+            f"generate_stream params:\n{params}\n"
+            f"{last_metrics.to_printable_string()}"
        )
        model_span.end(metadata={"output": previous_response})
        span.end()

@@ -238,6 +239,10 @@
                last_metrics,
                is_first_generate,
            )
            last_metrics = current_metrics
            logger.info(
                f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
            )
            return model_output
        else:
            for out in self.generate_stream(params):

@@ -320,7 +325,8 @@ class DefaultModelWorker(ModelWorker):
            yield model_output
        logger.info(
            f"\n\nfull stream output:\n{previous_response}\n\nmodel "
-            f"generate_stream params:\n{params}"
+            f"generate_stream params:\n{params}\n"
+            f"{last_metrics.to_printable_string()}"
        )
        model_span.end(metadata={"output": previous_response})
        span.end()

@@ -359,6 +365,10 @@
                last_metrics,
                is_first_generate,
            )
            last_metrics = current_metrics
            logger.info(
                f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
            )
            return model_output
        else:
            output = None

@@ -576,6 +586,17 @@ def _new_metrics_from_model_output(
        # time cost(seconds)
        duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
        metrics.speed_per_second = total_tokens / duration
    if total_tokens and "prefill_tokens_per_second" in usage:
        metrics.prefill_tokens_per_second = usage["prefill_tokens_per_second"]
    if total_tokens and "decode_tokens_per_second" in usage:
        # Decode speed
        metrics.decode_tokens_per_second = usage["decode_tokens_per_second"]
    elif total_tokens and metrics.first_token_time_ms:
        # time cost(seconds)
        duration = (metrics.current_time_ms - metrics.first_token_time_ms) / 1000.0
        if duration > 0:
            # Calculate decode speed if not provided
            metrics.decode_tokens_per_second = metrics.completion_tokens / duration

    current_gpu_infos = _get_current_cuda_memory()
    metrics.current_gpu_infos = current_gpu_infos
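A quick worked example of the fallback branch above, with illustrative numbers:

# First token arrived 0.8 s after start; we are now at 5.8 s with 400 completion tokens.
first_token_time_ms = 800.0
current_time_ms = 5_800.0
completion_tokens = 400

duration = (current_time_ms - first_token_time_ms) / 1000.0  # 5.0 s of decoding
decode_tokens_per_second = completion_tokens / duration      # 400 / 5.0 = 80.0 tokens/s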
@@ -2,10 +2,11 @@ import logging
from threading import Thread

import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.core import ModelOutput

from ...utils.hf_stream_utils import PerformanceMonitoringStreamer
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,

@@ -28,13 +29,15 @@ def huggingface_chat_generate_stream(
    temperature = float(params.get("temperature", 0.7))
    top_p = float(params.get("top_p", 1.0))
    echo = params.get("echo", False)
    # max_new_tokens = int(params.get("max_new_tokens", 2048))
    max_new_tokens = int(params.get("max_new_tokens", 4096))
    stop_token_ids = params.get("stop_token_ids", [])
    do_sample = params.get("do_sample", True)
    custom_stop_words = params.get("custom_stop_words", [])
    think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
    think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
    is_reasoning_model = params.get("is_reasoning_model", False)
    use_cache = params.get("use_cache", True)
    cache_implementation = params.get("cache_implementation")
    reasoning_patterns = [
        {"start": think_start_token, "end": think_end_token},
    ]

@@ -52,7 +55,9 @@
            token_kwargs["videos"] = videos
        if has_media:
            token_kwargs["padding"] = True
-    tokenize_results = tokenizer(**token_kwargs).to(device)
+    tokenize_results = tokenizer(**token_kwargs)
+    input_token_count = tokenize_results.input_ids.shape[1]  # Count input tokens
+    tokenize_results = tokenize_results.to(device)
    #
    # if model.config.is_encoder_decoder:
    #     max_src_len = context_len

@@ -63,29 +68,50 @@
    # # input_echo_len = len(input_ids)
    # input_ids = torch.as_tensor([input_ids], device=device)

-    streamer = TextIteratorStreamer(
-        tokenizer, skip_prompt=not echo, skip_special_tokens=True
+    streamer = PerformanceMonitoringStreamer(
+        tokenizer,
+        skip_prompt=not echo,
+        skip_special_tokens=True,
+        input_token_count=input_token_count,
    )

    base_kwargs = {
        "max_length": context_len,
        "temperature": temperature,
        "streamer": streamer,
        "top_p": top_p,
        "use_cache": use_cache,
        "max_new_tokens": max_new_tokens,
    }

    if stop_token_ids:
        base_kwargs["eos_token_id"] = stop_token_ids
    if do_sample is not None:
        base_kwargs["do_sample"] = do_sample
    if cache_implementation:
        base_kwargs["cache_implementation"] = cache_implementation

    logger.info(
        f"Predict with parameters: {base_kwargs}\ncustom_stop_words: "
        f"{custom_stop_words}"
    )

    generate_kwargs = {**tokenize_results, **base_kwargs}
    thread = Thread(target=model.generate, kwargs=generate_kwargs)

    def generate_with_resilience():
        try:
            _outputs = model.generate(**generate_kwargs, return_dict_in_generate=True)
        except torch.cuda.OutOfMemoryError as e:
            logger.warning(
                f"OOM error occurred: {e}. Trying cleanup and retrying generation."
            )
            torch.cuda.empty_cache()
            model.generate(**generate_kwargs)
        except Exception as ex:
            logger.error(f"Unexpected error during generation: {ex}")
            streamer.end()
            raise

    streamer.start_prefill()
    thread = Thread(target=generate_with_resilience)
    thread.start()
    text = ""
    usage = None

@@ -111,6 +137,15 @@
            extract_reasoning=is_reasoning_model,
            reasoning_patterns=reasoning_patterns,
        )
        perf_metrics = streamer.get_performance_metrics()
        usage = {
            "prompt_tokens": perf_metrics["input_token_count"],
            "completion_tokens": perf_metrics["total_tokens_generated"],
            "total_tokens": perf_metrics["input_token_count"]
            + perf_metrics["total_tokens_generated"],
        }
        usage.update(perf_metrics)

        yield ModelOutput.build(
            msg.content,
            msg.reasoning_content,

@@ -118,3 +153,4 @@
            usage=usage,
            is_reasoning_model=is_reasoning_model,
        )
    thread.join()
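The same streaming pattern can be reproduced outside the worker with plain Transformers. The snippet below is an illustrative sketch: the checkpoint name and max_new_tokens are arbitrary, and the import path for PerformanceMonitoringStreamer is inferred from the new file added later in this commit.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer

from dbgpt.model.utils.hf_stream_utils import PerformanceMonitoringStreamer

model_name = "Qwen/Qwen3-0.6B"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)
streamer = PerformanceMonitoringStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    input_token_count=inputs.input_ids.shape[1],
)

streamer.start_prefill()
Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64},
).start()

for chunk in streamer:  # TextIteratorStreamer yields decoded text pieces
    print(chunk, end="", flush=True)

print(streamer.get_performance_metrics())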
@@ -7,6 +7,7 @@ from vllm.utils import random_uuid

from dbgpt.core import ModelOutput

from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
    _DEFAULT_THINK_END_TOKEN,
    _DEFAULT_THINK_START_TOKEN,

@@ -85,6 +86,13 @@ async def generate_stream(
    )
    # vocab = tokenizer.get_vocab()

    # Initialize the performance monitor with estimated token count
    estimated_input_tokens = len(tokenizer.encode(prompt))
    perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)

    # Start measuring prefill phase
    perf_monitor.start_prefill()

    results_generator = model.generate(prompt, sampling_params, request_id)
    usage = None
    finish_reason = None

@@ -101,16 +109,30 @@
        completion_tokens = sum(
            len(output.token_ids) for output in request_output.outputs
        )
        # If this is the first iteration, update the input token count
        if perf_monitor.metrics.input_token_count != prompt_tokens:
            perf_monitor.metrics.input_token_count = prompt_tokens

        # Update performance metrics based on current token count
        perf_metrics = perf_monitor.on_tokens_received(completion_tokens)

        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        # Add performance metrics to usage
        usage.update(perf_metrics)

        finish_reason = (
            request_output.outputs[0].finish_reason
            if len(request_output.outputs) == 1
            else [output.finish_reason for output in request_output.outputs]
        )
        # Check if generation is complete
        is_complete = finish_reason is not None
        if is_complete:
            perf_monitor.end_generation()
        if text_outputs:
            # Tempora
            if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:
@@ -5,6 +5,7 @@ import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
import transformers
from torch import nn

@@ -119,3 +120,96 @@ def forward(
def replace_llama_attn_with_non_inplace_operations():
    """Avoid bugs in mps backend by not using in-place operations."""
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward


class ParQwen3MoeSparseMoeBlock(nn.Module):
    """
    Adapted from https://huggingface.co/Qwen/Qwen3-30B-A3B/discussions/13
    """

    def __init__(self, base_moe):
        super().__init__()
        self.base_moe = base_moe
        self.num_experts = base_moe.num_experts
        self.top_k = base_moe.top_k
        self.norm_topk_prob = base_moe.norm_topk_prob

        # gating
        self.gate = base_moe.gate
        self.experts = base_moe.experts

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        if self.norm_topk_prob:  # only diff with mixtral sparse moe block!
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)

        activated_experts = torch.unique(selected_experts)
        cuda_streams = [torch.cuda.Stream() for _ in activated_experts]
        # Loop over all available experts in the model and perform the computation on
        # each expert
        for expert_idx, cuda_stream in zip(activated_experts, cuda_streams):
            with torch.cuda.stream(cuda_stream):
                expert_layer = self.experts[expert_idx]
                idx, top_x = torch.where(expert_mask[expert_idx])

                # Index the correct hidden states and compute the expert hidden state
                # for the current expert. We need to make sure to multiply the output
                # hidden states by `routing_weights` on the corresponding tokens
                # (top-1 and top-2)
                current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
                current_hidden_states = (
                    expert_layer(current_state) * routing_weights[top_x, idx, None]
                )

                # However `index_add_` only support torch tensors for indexing so
                # we'll use the `top_x` tensor here.
                final_hidden_states.index_add_(
                    0, top_x, current_hidden_states.to(hidden_states.dtype)
                )
        torch.cuda.synchronize()
        final_hidden_states = final_hidden_states.reshape(
            batch_size, sequence_length, hidden_dim
        )
        return final_hidden_states, router_logits


def apply_qwen3_moe_monkey_patch(model):
    if not torch.cuda.is_available():
        # Only apply monkey patch if CUDA is available
        return model

    for layer in model.model.layers:
        if type(layer.mlp).__name__ == "Qwen3MoeSparseMoeBlock":
            layer.mlp = ParQwen3MoeSparseMoeBlock(layer.mlp)
    return model


def recovery_moe_monkey_patch(model):
    for layer in model.model.layers:
        if type(layer.mlp).__name__ == "ParQwen3MoeSparseMoeBlock" and hasattr(
            layer.mlp, "base_moe"
        ):
            layer.mlp = layer.mlp.base_moe
    return model
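Applying the patch is a single call after loading a Qwen3 MoE checkpoint. The sketch below is illustrative: the model name comes from the linked Hugging Face discussion, loading it needs substantial GPU memory, and the absolute import path is inferred from the relative import used by Qwen3Adapter above. The patched block runs only the experts the router actually selected and launches each one on its own CUDA stream, which is where the decode-speed gain comes from.

import torch
from transformers import AutoModelForCausalLM

# Path inferred from "from ..llm.monkey_patch import ..." in the adapter hunk above.
from dbgpt.model.llm.monkey_patch import (
    apply_qwen3_moe_monkey_patch,
    recovery_moe_monkey_patch,
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B", torch_dtype=torch.bfloat16, device_map="cuda"
)
model = apply_qwen3_moe_monkey_patch(model)  # wraps each Qwen3MoeSparseMoeBlock
# ... run generation as usual ...
model = recovery_moe_monkey_patch(model)     # restores the original MoE blocks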
packages/dbgpt-core/src/dbgpt/model/utils/hf_stream_utils.py (new file, 70 lines added)
@@ -0,0 +1,70 @@
import logging

from transformers import TextIteratorStreamer

from .llm_metrics import LLMPerformanceMonitor

logger = logging.getLogger(__name__)


class PerformanceMonitoringStreamer(TextIteratorStreamer):
    """
    Extended TextIteratorStreamer that monitors LLM inference performance.
    Uses the generic LLMPerformanceMonitor for performance tracking.
    """

    def __init__(
        self,
        tokenizer,
        skip_prompt=False,
        timeout=None,
        input_token_count=0,
        **decode_kwargs,
    ):
        super().__init__(
            tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs
        )

        # Initialize the performance monitor
        self.perf_monitor = LLMPerformanceMonitor(input_token_count=input_token_count)

        # Additional flags for streamer-specific behavior
        self.is_prompt_token = True  # Flag to track if current tokens are from prompt

    def start_prefill(self):
        """Mark the beginning of the prefill phase"""
        self.perf_monitor.start_prefill()

    def put(self, value):
        """
        Receive tokens and track performance metrics.
        Automatically detects prefill/decode phase transitions.
        """
        # Skip counting if these are prompt tokens and skip_prompt is True
        if self.skip_prompt and self.is_prompt_token:
            self.is_prompt_token = False  # Mark that we've processed the prompt tokens
            logger.debug("Skipping prompt tokens for performance measurement")
            super().put(value)  # Call parent method to continue flow
            return

        # Calculate number of new tokens
        token_count = len(value.tolist())
        total_token_count = self.perf_monitor.metrics.prev_tokens_count + token_count

        # Update performance metrics
        self.perf_monitor.on_tokens_received(total_token_count)

        # Call the parent method to continue the original flow
        super().put(value)

    def end(self):
        """End generation and finalize performance metrics"""
        # Finalize metrics
        self.perf_monitor.end_generation()

        # Call the parent method to continue the original flow
        super().end()

    def get_performance_metrics(self):
        """Get performance metrics in a format suitable for API responses"""
        return self.perf_monitor.get_metrics_dict()
packages/dbgpt-core/src/dbgpt/model/utils/llm_metrics.py (new file, 194 lines added)
@@ -0,0 +1,194 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclass
class LLMPerformanceMetrics:
    """Performance metrics for LLM inference, including prefill and decode phases"""

    # Token counts
    input_token_count: int = 0
    total_tokens_generated: int = 0
    prev_tokens_count: int = 0

    # Time measurements in nanoseconds
    start_time_ns: int = field(default_factory=time.time_ns)
    prefill_start_time_ns: Optional[int] = None
    prefill_end_time_ns: Optional[int] = None
    prefill_time_ns: Optional[int] = None
    decode_start_time_ns: Optional[int] = None
    total_time_ns: Optional[int] = None

    # Timestamps and measurements
    token_timestamps_ns: List[int] = field(default_factory=list)
    decode_times_ns: List[int] = field(default_factory=list)

    # Calculated metrics (tokens per second)
    prefill_tokens_per_second: Optional[float] = None
    decode_tokens_per_second: Optional[float] = None
    end_to_end_tokens_per_second: Optional[float] = None

    # Additional computed values
    avg_decode_time: Optional[float] = None

    def to_dict(self) -> Dict[str, any]:
        """Convert metrics to a dictionary for API response, with times in seconds"""
        metrics = {
            "input_token_count": self.input_token_count,
            "total_tokens_generated": self.total_tokens_generated,
        }

        # Add time metrics in seconds
        if self.prefill_time_ns is not None:
            metrics["prefill_time"] = self.prefill_time_ns / 1e9
            metrics["prefill_tokens_per_second"] = self.prefill_tokens_per_second or 0

        if self.total_time_ns is not None:
            metrics["total_time"] = self.total_time_ns / 1e9

        if self.decode_times_ns:
            metrics["avg_decode_time"] = self.avg_decode_time
            metrics["decode_tokens_per_second"] = self.decode_tokens_per_second or 0

        # Add throughput metrics
        metrics["end_to_end_tokens_per_second"] = self.end_to_end_tokens_per_second or 0

        return metrics


class LLMPerformanceMonitor:
    """Generic performance monitor for LLM inference that tracks prefill and decode
    phases"""

    def __init__(self, input_token_count: int = 0):
        # Performance metrics
        self.metrics = LLMPerformanceMetrics(input_token_count=input_token_count)

        # Phase flags
        self.prefill_started: bool = False
        self.first_token_received: bool = False

    def start_prefill(self) -> int:
        """Mark the beginning of the prefill phase using nanosecond timestamp"""
        timestamp = time.time_ns()
        self.metrics.prefill_start_time_ns = timestamp
        self.prefill_started = True
        return timestamp

    def on_tokens_received(self, current_token_count: int) -> Dict[str, any]:
        """
        Called when new tokens are received from LLM
        Returns updated performance metrics
        """
        current_time_ns = time.time_ns()

        # Calculate new tokens received in this batch
        new_tokens = current_token_count - self.metrics.prev_tokens_count

        # Auto-detect the end of prefill / start of decode phase
        if self.prefill_started and not self.first_token_received and new_tokens > 0:
            # This is the first tokens batch - mark the end of prefill phase
            self.metrics.prefill_end_time_ns = current_time_ns
            self.metrics.prefill_time_ns = (
                self.metrics.prefill_end_time_ns - self.metrics.prefill_start_time_ns
            )

            # Convert nanoseconds to seconds for calculation and logging
            prefill_time_sec = self.metrics.prefill_time_ns / 1e9

            # Calculate prefill speed
            if self.metrics.input_token_count > 0 and prefill_time_sec > 0:
                self.metrics.prefill_tokens_per_second = (
                    self.metrics.input_token_count / prefill_time_sec
                )
                logger.info(
                    f"Prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
                    f"tokens/s for {self.metrics.input_token_count} tokens"
                )

            # Mark the beginning of decode phase
            self.metrics.decode_start_time_ns = current_time_ns
            self.first_token_received = True

        # Record token generation data
        if self.first_token_received and new_tokens > 0:
            # If we've already received tokens, add decode time for this batch
            if len(self.metrics.token_timestamps_ns) > 0:
                last_timestamp_ns = self.metrics.token_timestamps_ns[-1]
                token_decode_time_ns = current_time_ns - last_timestamp_ns

                # Distribute the time evenly across all new tokens in this batch
                time_per_token = token_decode_time_ns / new_tokens
                for _ in range(new_tokens):
                    self.metrics.decode_times_ns.append(time_per_token)

            # Record the current token batch timestamp
            self.metrics.token_timestamps_ns.append(current_time_ns)
            self.metrics.total_tokens_generated += new_tokens
            self.metrics.prev_tokens_count = current_token_count

        # Calculate current metrics
        self._update_metrics(current_time_ns)

        return self.get_metrics_dict()

    def _update_metrics(self, current_time_ns: Optional[int] = None) -> None:
        """Update the performance metrics based on current state"""
        if current_time_ns is None:
            current_time_ns = time.time_ns()

        # Record total time
        self.metrics.total_time_ns = current_time_ns - self.metrics.start_time_ns

        # Calculate average decode speed
        if self.metrics.decode_times_ns:
            # Convert to seconds
            decode_times_sec = [t / 1e9 for t in self.metrics.decode_times_ns]
            self.metrics.avg_decode_time = sum(decode_times_sec) / len(decode_times_sec)
            self.metrics.decode_tokens_per_second = 1.0 / self.metrics.avg_decode_time

        # Calculate end-to-end throughput
        total_time_sec = self.metrics.total_time_ns / 1e9
        if total_time_sec > 0:
            total_tokens = (
                self.metrics.input_token_count + self.metrics.total_tokens_generated
            )
            self.metrics.end_to_end_tokens_per_second = total_tokens / total_time_sec

    def end_generation(self) -> Dict[str, any]:
        """Mark the end of generation and finalize metrics"""
        current_time_ns = time.time_ns()
        self._update_metrics(current_time_ns)

        # Log final performance data
        total_time_sec = self.metrics.total_time_ns / 1e9
        logger.info(f"Generation complete. Total time: {total_time_sec:.6f}s")

        if self.metrics.prefill_tokens_per_second:
            logger.info(
                f"Final prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
                "tokens/s"
            )

        if self.metrics.decode_tokens_per_second:
            logger.info(
                f"Final decode speed: {self.metrics.decode_tokens_per_second:.2f} "
                "tokens/s"
            )

        if self.metrics.end_to_end_tokens_per_second:
            logger.info(
                "End-to-end throughput: "
                f"{self.metrics.end_to_end_tokens_per_second:.2f} tokens/s"
            )

        return self.get_metrics_dict()

    def get_metrics_dict(self) -> Dict[str, any]:
        """Get performance metrics as dictionary, converting nanoseconds to seconds
        for external use"""
        return self.metrics.to_dict()
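A small self-contained sketch of the monitor's contract: on_tokens_received() takes the cumulative completion-token count (not a delta), and the first non-empty batch closes the prefill phase. The sleeps simulate inference time, and the import path is derived from the file location above.

import time

from dbgpt.model.utils.llm_metrics import LLMPerformanceMonitor

monitor = LLMPerformanceMonitor(input_token_count=32)
monitor.start_prefill()
time.sleep(0.05)                # pretend prefill takes ~50 ms

monitor.on_tokens_received(1)   # first token: ends prefill, starts decode
time.sleep(0.10)
monitor.on_tokens_received(9)   # cumulative count: 8 new tokens in this batch

metrics = monitor.end_generation()
print(metrics["prefill_tokens_per_second"])  # ~640 (32 prompt tokens / ~0.05 s)
print(metrics["decode_tokens_per_second"])   # ~80 (8 tokens spread over ~0.10 s)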