feat(model): Add patch for Qwen3 moe (#2676)

Fangyin Cheng 2025-05-12 14:44:25 +08:00 committed by GitHub
parent 3a65e1b65f
commit 4f39850ac1
10 changed files with 560 additions and 15 deletions

View File

@@ -65,7 +65,14 @@ class ModelInferenceMetrics:
"""The total number of tokens (prompt plus completion)."""
speed_per_second: Optional[float] = None
"""The average number of tokens generated per second. Includes both prefill and
decode time."""
prefill_tokens_per_second: Optional[float] = None
"""Prefill speed in tokens per second."""
decode_tokens_per_second: Optional[float] = None
"""The average number of tokens generated per second during the decode phase."""
current_gpu_infos: Optional[List[GPUInfo]] = None
"""Current gpu information, all devices"""
@@ -97,6 +104,12 @@ class ModelInferenceMetrics:
completion_tokens = last_metrics.completion_tokens if last_metrics else None
total_tokens = last_metrics.total_tokens if last_metrics else None
speed_per_second = last_metrics.speed_per_second if last_metrics else None
prefill_tokens_per_second = (
last_metrics.prefill_tokens_per_second if last_metrics else None
)
decode_tokens_per_second = (
last_metrics.decode_tokens_per_second if last_metrics else None
)
current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None
@@ -116,6 +129,8 @@ class ModelInferenceMetrics:
completion_tokens=completion_tokens,
total_tokens=total_tokens,
speed_per_second=speed_per_second,
prefill_tokens_per_second=prefill_tokens_per_second,
decode_tokens_per_second=decode_tokens_per_second,
current_gpu_infos=current_gpu_infos,
avg_gpu_infos=avg_gpu_infos,
)
@@ -124,6 +139,65 @@ class ModelInferenceMetrics:
"""Convert the model inference metrics to dict."""
return asdict(self)
def to_printable_string(self) -> str:
"""Stringify the metrics in an elegant format.
Returns:
str: A formatted string containing first token latency, prefill speed,
decode speed, prompt tokens and completion tokens.
"""
lines = []
# Calculate first token latency if possible
first_token_latency = None
if self.first_token_time_ms is not None and self.start_time_ms is not None:
first_token_latency = (
self.first_token_time_ms - self.start_time_ms
) / 1000.0
# Add section header
lines.append("=== Model Inference Metrics ===")
# Latency metrics
lines.append("\n▶ Latency:")
if first_token_latency is not None:
lines.append(f" • First Token Latency: {first_token_latency:.3f}s")
else:
lines.append(" • First Token Latency: N/A")
# Speed metrics
lines.append("\n▶ Speed:")
if self.prefill_tokens_per_second is not None:
lines.append(
f" • Prefill Speed: {self.prefill_tokens_per_second:.2f} tokens/s"
)
else:
lines.append(" • Prefill Speed: N/A")
if self.decode_tokens_per_second is not None:
lines.append(
f" • Decode Speed: {self.decode_tokens_per_second:.2f} tokens/s"
)
else:
lines.append(" • Decode Speed: N/A")
# Token counts
lines.append("\n▶ Tokens:")
if self.prompt_tokens is not None:
lines.append(f" • Prompt Tokens: {self.prompt_tokens}")
else:
lines.append(" • Prompt Tokens: N/A")
if self.completion_tokens is not None:
lines.append(f" • Completion Tokens: {self.completion_tokens}")
else:
lines.append(" • Completion Tokens: N/A")
if self.total_tokens is not None:
lines.append(f" • Total Tokens: {self.total_tokens}")
return "\n".join(lines)
@dataclass
@PublicAPI(stability="beta")

View File

@@ -223,6 +223,12 @@ class LLMModelAdapter(ABC):
"""Load model and tokenizer"""
raise NotImplementedError
def model_patch(
self, deploy_model_params: LLMDeployModelParameters
) -> Optional[Callable[[Any], Any]]:
"""Patch function for model"""
return None
def parse_max_length(self, model, tokenizer) -> Optional[int]:
"""Parse the max_length of the model.

View File

@@ -617,6 +617,13 @@ class Qwen3Adapter(QwenAdapter):
" transformers package."
)
def model_patch(self, deploy_model_params: LLMDeployModelParameters):
"""Apply the monkey patch to moe model for high inference speed."""
from ..llm.monkey_patch import apply_qwen3_moe_monkey_patch
return apply_qwen3_moe_monkey_patch
def is_reasoning_model(
self,
deploy_model_params: LLMDeployModelParameters,

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import logging
from typing import Any, Callable, Dict, Optional, cast
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter
@@ -146,8 +146,16 @@ def huggingface_loader(
if model_params.attn_implementation:
kwargs["attn_implementation"] = model_params.attn_implementation
model_patch = llm_adapter.model_patch(model_params)
model, tokenizer = _hf_try_load_default_quantization_model(
model_path,
llm_adapter,
device,
num_gpus,
model_params,
kwargs,
model_patch=model_patch,
)
if model:
return model, tokenizer
@ -176,7 +184,7 @@ def huggingface_loader(
compress_module(model, device)
return _hf_handle_model_and_tokenizer(
model, tokenizer, device, num_gpus, model_params, model_patch=model_patch
)
@@ -187,6 +195,7 @@ def _hf_try_load_default_quantization_model(
num_gpus: int,
model_params: HFLLMDeployModelParameters,
kwargs: Dict[str, Any],
model_patch: Optional[Callable[[Any], Any]] = None,
):
"""Try load default quantization model(Support by huggingface default)"""
cloned_kwargs = {k: v for k, v in kwargs.items()}
@ -216,7 +225,13 @@ def _hf_try_load_default_quantization_model(
if model:
logger.info(f"Load default quantization model {model_name} success")
return _hf_handle_model_and_tokenizer(
model,
tokenizer,
device,
num_gpus,
model_params,
to=False,
model_patch=model_patch,
)
return None, None
except Exception as e:
@@ -233,6 +248,7 @@ def _hf_handle_model_and_tokenizer(
num_gpus: int,
model_params: HFLLMDeployModelParameters,
to: bool = True,
model_patch: Optional[Callable[[Any], Any]] = None,
):
if (device == "cuda" and num_gpus == 1) or device == "mps" and tokenizer:
# TODO: Check cpu_offloading
@ -243,6 +259,11 @@ def _hf_handle_model_and_tokenizer(
pass
except AttributeError:
pass
try:
if model_patch:
model = model_patch(model)
except Exception:
pass
if model_params.verbose:
print(model)
return model, tokenizer

View File

@@ -197,7 +197,8 @@ class DefaultModelWorker(ModelWorker):
yield model_output
logger.info(
f"\n\nfull stream output:\n{previous_response}\n\nmodel "
f"generate_stream params:\n{params}\n"
f"{last_metrics.to_printable_string()}"
)
model_span.end(metadata={"output": previous_response})
span.end()
@ -238,6 +239,10 @@ class DefaultModelWorker(ModelWorker):
last_metrics,
is_first_generate,
)
last_metrics = current_metrics
logger.info(
f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
)
return model_output
else:
for out in self.generate_stream(params):
@@ -320,7 +325,8 @@ class DefaultModelWorker(ModelWorker):
yield model_output
logger.info(
f"\n\nfull stream output:\n{previous_response}\n\nmodel "
f"generate_stream params:\n{params}\n"
f"{last_metrics.to_printable_string()}"
)
model_span.end(metadata={"output": previous_response})
span.end()
@ -359,6 +365,10 @@ class DefaultModelWorker(ModelWorker):
last_metrics,
is_first_generate,
)
last_metrics = current_metrics
logger.info(
f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
)
return model_output
else:
output = None
@@ -576,6 +586,17 @@ def _new_metrics_from_model_output(
# time cost(seconds)
duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
metrics.speed_per_second = total_tokens / duration
if total_tokens and "prefill_tokens_per_second" in usage:
metrics.prefill_tokens_per_second = usage["prefill_tokens_per_second"]
if total_tokens and "decode_tokens_per_second" in usage:
# Decode speed
metrics.decode_tokens_per_second = usage["decode_tokens_per_second"]
elif total_tokens and metrics.first_token_time_ms:
# time cost(seconds)
duration = (metrics.current_time_ms - metrics.first_token_time_ms) / 1000.0
if duration > 0:
# Calculate decode speed if not provided
metrics.decode_tokens_per_second = metrics.completion_tokens / duration
current_gpu_infos = _get_current_cuda_memory()
metrics.current_gpu_infos = current_gpu_infos
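As a worked example of the fallback decode-speed calculation above (all values hypothetical), the decode speed is simply the completion tokens divided by the time elapsed since the first token:

first_token_time_ms = 1_000   # hypothetical timestamps, for illustration only
current_time_ms = 6_000
completion_tokens = 200
duration = (current_time_ms - first_token_time_ms) / 1000.0  # 5.0 seconds in the decode phase
decode_tokens_per_second = completion_tokens / duration      # 200 / 5.0 = 40.0 tokens/s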

View File

@@ -2,10 +2,11 @@ import logging
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from dbgpt.core import ModelOutput
from ...utils.hf_stream_utils import PerformanceMonitoringStreamer
from ...utils.parse_utils import (
_DEFAULT_THINK_END_TOKEN,
_DEFAULT_THINK_START_TOKEN,
@@ -28,13 +29,15 @@ def huggingface_chat_generate_stream(
temperature = float(params.get("temperature", 0.7))
top_p = float(params.get("top_p", 1.0))
echo = params.get("echo", False)
max_new_tokens = int(params.get("max_new_tokens", 4096))
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", True)
custom_stop_words = params.get("custom_stop_words", [])
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
use_cache = params.get("use_cache", True)
cache_implementation = params.get("cache_implementation")
reasoning_patterns = [
{"start": think_start_token, "end": think_end_token},
]
@@ -52,7 +55,9 @@ def huggingface_chat_generate_stream(
token_kwargs["videos"] = videos
if has_media:
token_kwargs["padding"] = True
tokenize_results = tokenizer(**token_kwargs)
input_token_count = tokenize_results.input_ids.shape[1] # Count input tokens
tokenize_results = tokenize_results.to(device)
#
# if model.config.is_encoder_decoder:
# max_src_len = context_len
@@ -63,29 +68,50 @@ def huggingface_chat_generate_stream(
# # input_echo_len = len(input_ids)
# input_ids = torch.as_tensor([input_ids], device=device)
streamer = PerformanceMonitoringStreamer(
tokenizer,
skip_prompt=not echo,
skip_special_tokens=True,
input_token_count=input_token_count,
)
base_kwargs = {
"max_length": context_len,
"temperature": temperature, "temperature": temperature,
"streamer": streamer, "streamer": streamer,
"top_p": top_p, "top_p": top_p,
"use_cache": use_cache,
"max_new_tokens": max_new_tokens,
}
if stop_token_ids:
base_kwargs["eos_token_id"] = stop_token_ids
if do_sample is not None:
base_kwargs["do_sample"] = do_sample
if cache_implementation:
base_kwargs["cache_implementation"] = cache_implementation
logger.info(
f"Predict with parameters: {base_kwargs}\ncustom_stop_words: "
f"{custom_stop_words}"
)
generate_kwargs = {**tokenize_results, **base_kwargs}
thread = Thread(target=model.generate, kwargs=generate_kwargs)
def generate_with_resilience():
try:
_outputs = model.generate(**generate_kwargs, return_dict_in_generate=True)
except torch.cuda.OutOfMemoryError as e:
logger.warning(
f"OOM error occurred: {e}. Trying cleanup and retrying generation."
)
torch.cuda.empty_cache()
model.generate(**generate_kwargs)
except Exception as ex:
logger.error(f"Unexpected error during generation: {ex}")
streamer.end()
raise
streamer.start_prefill()
thread = Thread(target=generate_with_resilience)
thread.start()
text = ""
usage = None
@@ -111,6 +137,15 @@ def huggingface_chat_generate_stream(
extract_reasoning=is_reasoning_model,
reasoning_patterns=reasoning_patterns,
)
perf_metrics = streamer.get_performance_metrics()
usage = {
"prompt_tokens": perf_metrics["input_token_count"],
"completion_tokens": perf_metrics["total_tokens_generated"],
"total_tokens": perf_metrics["input_token_count"]
+ perf_metrics["total_tokens_generated"],
}
usage.update(perf_metrics)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,
@@ -118,3 +153,4 @@ def huggingface_chat_generate_stream(
usage=usage,
is_reasoning_model=is_reasoning_model,
)
thread.join()

View File

@@ -7,6 +7,7 @@ from vllm.utils import random_uuid
from dbgpt.core import ModelOutput
from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
_DEFAULT_THINK_END_TOKEN,
_DEFAULT_THINK_START_TOKEN,
@@ -85,6 +86,13 @@ async def generate_stream(
)
# vocab = tokenizer.get_vocab()
# Initialize the performance monitor with estimated token count
estimated_input_tokens = len(tokenizer.encode(prompt))
perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)
# Start measuring prefill phase
perf_monitor.start_prefill()
results_generator = model.generate(prompt, sampling_params, request_id)
usage = None
finish_reason = None
@ -101,16 +109,30 @@ async def generate_stream(
completion_tokens = sum(
len(output.token_ids) for output in request_output.outputs
)
# If this is the first iteration, update the input token count
if perf_monitor.metrics.input_token_count != prompt_tokens:
perf_monitor.metrics.input_token_count = prompt_tokens
# Update performance metrics based on current token count
perf_metrics = perf_monitor.on_tokens_received(completion_tokens)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
# Add performance metrics to usage
usage.update(perf_metrics)
finish_reason = (
request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs]
)
# Check if generation is complete
is_complete = finish_reason is not None
if is_complete:
perf_monitor.end_generation()
if text_outputs:
# Tempora
if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:

View File

@@ -5,6 +5,7 @@ import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from torch import nn
@@ -119,3 +120,96 @@ def forward(
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
class ParQwen3MoeSparseMoeBlock(nn.Module):
"""
Adapted from https://huggingface.co/Qwen/Qwen3-30B-A3B/discussions/13
"""
def __init__(self, base_moe):
super().__init__()
self.base_moe = base_moe
self.num_experts = base_moe.num_experts
self.top_k = base_moe.top_k
self.norm_topk_prob = base_moe.norm_topk_prob
# gating
self.gate = base_moe.gate
self.experts = base_moe.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)
activated_experts = torch.unique(selected_experts)
cuda_streams = [torch.cuda.Stream() for _ in activated_experts]
# Loop over all available experts in the model and perform the computation on
# each expert
for expert_idx, cuda_stream in zip(activated_experts, cuda_streams):
with torch.cuda.stream(cuda_stream):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state
# for the current expert. We need to make sure to multiply the output
# hidden states by `routing_weights` on the corresponding tokens
# (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = (
expert_layer(current_state) * routing_weights[top_x, idx, None]
)
# However `index_add_` only support torch tensors for indexing so
# we'll use the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
torch.cuda.synchronize()
final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits
def apply_qwen3_moe_monkey_patch(model):
if not torch.cuda.is_available():
# Only apply monkey patch if CUDA is available
return model
for layer in model.model.layers:
if type(layer.mlp).__name__ == "Qwen3MoeSparseMoeBlock":
layer.mlp = ParQwen3MoeSparseMoeBlock(layer.mlp)
return model
def recovery_moe_monkey_patch(model):
for layer in model.model.layers:
if type(layer.mlp).__name__ == "ParQwen3MoeSparseMoeBlock" and hasattr(
layer.mlp, "base_moe"
):
layer.mlp = layer.mlp.base_moe
return model
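A minimal usage sketch for the two helpers above, assuming a CUDA device and a Qwen3 MoE checkpoint (the model id and dtype are illustrative, not prescribed by this patch):

import torch
from transformers import AutoModelForCausalLM

# Any checkpoint whose decoder layers use Qwen3MoeSparseMoeBlock will be patched.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
model = apply_qwen3_moe_monkey_patch(model)   # run activated experts on separate CUDA streams
# ... run generation as usual ...
model = recovery_moe_monkey_patch(model)      # restore the original sparse MoE blocks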

View File

@@ -0,0 +1,70 @@
import logging
from transformers import TextIteratorStreamer
from .llm_metrics import LLMPerformanceMonitor
logger = logging.getLogger(__name__)
class PerformanceMonitoringStreamer(TextIteratorStreamer):
"""
Extended TextIteratorStreamer that monitors LLM inference performance.
Uses the generic LLMPerformanceMonitor for performance tracking.
"""
def __init__(
self,
tokenizer,
skip_prompt=False,
timeout=None,
input_token_count=0,
**decode_kwargs,
):
super().__init__(
tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs
)
# Initialize the performance monitor
self.perf_monitor = LLMPerformanceMonitor(input_token_count=input_token_count)
# Additional flags for streamer-specific behavior
self.is_prompt_token = True # Flag to track if current tokens are from prompt
def start_prefill(self):
"""Mark the beginning of the prefill phase"""
self.perf_monitor.start_prefill()
def put(self, value):
"""
Receive tokens and track performance metrics.
Automatically detects prefill/decode phase transitions.
"""
# Skip counting if these are prompt tokens and skip_prompt is True
if self.skip_prompt and self.is_prompt_token:
self.is_prompt_token = False # Mark that we've processed the prompt tokens
logger.debug("Skipping prompt tokens for performance measurement")
super().put(value) # Call parent method to continue flow
return
# Calculate number of new tokens
token_count = len(value.tolist())
total_token_count = self.perf_monitor.metrics.prev_tokens_count + token_count
# Update performance metrics
self.perf_monitor.on_tokens_received(total_token_count)
# Call the parent method to continue the original flow
super().put(value)
def end(self):
"""End generation and finalize performance metrics"""
# Finalize metrics
self.perf_monitor.end_generation()
# Call the parent method to continue the original flow
super().end()
def get_performance_metrics(self):
"""Get performance metrics in a format suitable for API responses"""
return self.perf_monitor.get_metrics_dict()
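A minimal wiring sketch for the streamer, mirroring the huggingface stream path changed earlier in this commit (model and tokenizer are assumed to be an already-loaded HF causal LM pair):

from threading import Thread

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
streamer = PerformanceMonitoringStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    input_token_count=inputs.input_ids.shape[1],
)
streamer.start_prefill()  # mark the prefill start right before generation begins
Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 128},
).start()
for new_text in streamer:  # the streamer is iterable; text arrives as it is decoded
    print(new_text, end="")
print(streamer.get_performance_metrics())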

View File

@@ -0,0 +1,194 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class LLMPerformanceMetrics:
"""Performance metrics for LLM inference, including prefill and decode phases"""
# Token counts
input_token_count: int = 0
total_tokens_generated: int = 0
prev_tokens_count: int = 0
# Time measurements in nanoseconds
start_time_ns: int = field(default_factory=time.time_ns)
prefill_start_time_ns: Optional[int] = None
prefill_end_time_ns: Optional[int] = None
prefill_time_ns: Optional[int] = None
decode_start_time_ns: Optional[int] = None
total_time_ns: Optional[int] = None
# Timestamps and measurements
token_timestamps_ns: List[int] = field(default_factory=list)
decode_times_ns: List[int] = field(default_factory=list)
# Calculated metrics (tokens per second)
prefill_tokens_per_second: Optional[float] = None
decode_tokens_per_second: Optional[float] = None
end_to_end_tokens_per_second: Optional[float] = None
# Additional computed values
avg_decode_time: Optional[float] = None
def to_dict(self) -> Dict[str, any]:
"""Convert metrics to a dictionary for API response, with times in seconds"""
metrics = {
"input_token_count": self.input_token_count,
"total_tokens_generated": self.total_tokens_generated,
}
# Add time metrics in seconds
if self.prefill_time_ns is not None:
metrics["prefill_time"] = self.prefill_time_ns / 1e9
metrics["prefill_tokens_per_second"] = self.prefill_tokens_per_second or 0
if self.total_time_ns is not None:
metrics["total_time"] = self.total_time_ns / 1e9
if self.decode_times_ns:
metrics["avg_decode_time"] = self.avg_decode_time
metrics["decode_tokens_per_second"] = self.decode_tokens_per_second or 0
# Add throughput metrics
metrics["end_to_end_tokens_per_second"] = self.end_to_end_tokens_per_second or 0
return metrics
class LLMPerformanceMonitor:
"""Generic performance monitor for LLM inference that tracks prefill and decode
phases"""
def __init__(self, input_token_count: int = 0):
# Performance metrics
self.metrics = LLMPerformanceMetrics(input_token_count=input_token_count)
# Phase flags
self.prefill_started: bool = False
self.first_token_received: bool = False
def start_prefill(self) -> int:
"""Mark the beginning of the prefill phase using nanosecond timestamp"""
timestamp = time.time_ns()
self.metrics.prefill_start_time_ns = timestamp
self.prefill_started = True
return timestamp
def on_tokens_received(self, current_token_count: int) -> Dict[str, any]:
"""
Called when new tokens are received from LLM
Returns updated performance metrics
"""
current_time_ns = time.time_ns()
# Calculate new tokens received in this batch
new_tokens = current_token_count - self.metrics.prev_tokens_count
# Auto-detect the end of prefill / start of decode phase
if self.prefill_started and not self.first_token_received and new_tokens > 0:
# This is the first tokens batch - mark the end of prefill phase
self.metrics.prefill_end_time_ns = current_time_ns
self.metrics.prefill_time_ns = (
self.metrics.prefill_end_time_ns - self.metrics.prefill_start_time_ns
)
# Convert nanoseconds to seconds for calculation and logging
prefill_time_sec = self.metrics.prefill_time_ns / 1e9
# Calculate prefill speed
if self.metrics.input_token_count > 0 and prefill_time_sec > 0:
self.metrics.prefill_tokens_per_second = (
self.metrics.input_token_count / prefill_time_sec
)
logger.info(
f"Prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
f"tokens/s for {self.metrics.input_token_count} tokens"
)
# Mark the beginning of decode phase
self.metrics.decode_start_time_ns = current_time_ns
self.first_token_received = True
# Record token generation data
if self.first_token_received and new_tokens > 0:
# If we've already received tokens, add decode time for this batch
if len(self.metrics.token_timestamps_ns) > 0:
last_timestamp_ns = self.metrics.token_timestamps_ns[-1]
token_decode_time_ns = current_time_ns - last_timestamp_ns
# Distribute the time evenly across all new tokens in this batch
time_per_token = token_decode_time_ns / new_tokens
for _ in range(new_tokens):
self.metrics.decode_times_ns.append(time_per_token)
# Record the current token batch timestamp
self.metrics.token_timestamps_ns.append(current_time_ns)
self.metrics.total_tokens_generated += new_tokens
self.metrics.prev_tokens_count = current_token_count
# Calculate current metrics
self._update_metrics(current_time_ns)
return self.get_metrics_dict()
def _update_metrics(self, current_time_ns: Optional[int] = None) -> None:
"""Update the performance metrics based on current state"""
if current_time_ns is None:
current_time_ns = time.time_ns()
# Record total time
self.metrics.total_time_ns = current_time_ns - self.metrics.start_time_ns
# Calculate average decode speed
if self.metrics.decode_times_ns:
# Convert to seconds
decode_times_sec = [t / 1e9 for t in self.metrics.decode_times_ns]
self.metrics.avg_decode_time = sum(decode_times_sec) / len(decode_times_sec)
self.metrics.decode_tokens_per_second = 1.0 / self.metrics.avg_decode_time
# Calculate end-to-end throughput
total_time_sec = self.metrics.total_time_ns / 1e9
if total_time_sec > 0:
total_tokens = (
self.metrics.input_token_count + self.metrics.total_tokens_generated
)
self.metrics.end_to_end_tokens_per_second = total_tokens / total_time_sec
def end_generation(self) -> Dict[str, any]:
"""Mark the end of generation and finalize metrics"""
current_time_ns = time.time_ns()
self._update_metrics(current_time_ns)
# Log final performance data
total_time_sec = self.metrics.total_time_ns / 1e9
logger.info(f"Generation complete. Total time: {total_time_sec:.6f}s")
if self.metrics.prefill_tokens_per_second:
logger.info(
f"Final prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
"tokens/s"
)
if self.metrics.decode_tokens_per_second:
logger.info(
f"Final decode speed: {self.metrics.decode_tokens_per_second:.2f} "
"tokens/s"
)
if self.metrics.end_to_end_tokens_per_second:
logger.info(
"End-to-end throughput: "
f"{self.metrics.end_to_end_tokens_per_second:.2f} tokens/s"
)
return self.get_metrics_dict()
def get_metrics_dict(self) -> Dict[str, any]:
"""Get performance metrics as dictionary, converting nanoseconds to seconds
for external use"""
return self.metrics.to_dict()
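A minimal call-order sketch for the monitor (sleep durations and token counts are synthetic; real callers pass the cumulative completion-token count they have observed so far):

import time

monitor = LLMPerformanceMonitor(input_token_count=128)
monitor.start_prefill()             # call just before generation is kicked off
time.sleep(0.05)                    # stands in for real prefill latency
monitor.on_tokens_received(8)       # first batch: ends prefill, starts the decode phase
time.sleep(0.05)
monitor.on_tokens_received(16)      # cumulative count; per-token decode times are recorded
final_metrics = monitor.end_generation()
print(final_metrics["prefill_tokens_per_second"], final_metrics["decode_tokens_per_second"])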