feat(model): Add patch for Qwen3 moe (#2676)

Fangyin Cheng 2025-05-12 14:44:25 +08:00 committed by GitHub
parent 3a65e1b65f
commit 4f39850ac1
10 changed files with 560 additions and 15 deletions

View File

@@ -65,7 +65,14 @@ class ModelInferenceMetrics:
"""The total number of tokens (prompt plus completion)."""
speed_per_second: Optional[float] = None
"""The average number of tokens generated per second."""
"""The average number of tokens generated per second. Includes both prefill and
decode time."""
prefill_tokens_per_second: Optional[float] = None
"""Prefill speed in tokens per second."""
decode_tokens_per_second: Optional[float] = None
"""The average number of tokens generated per second during the decode phase."""
current_gpu_infos: Optional[List[GPUInfo]] = None
"""Current gpu information, all devices"""
@@ -97,6 +104,12 @@ class ModelInferenceMetrics:
completion_tokens = last_metrics.completion_tokens if last_metrics else None
total_tokens = last_metrics.total_tokens if last_metrics else None
speed_per_second = last_metrics.speed_per_second if last_metrics else None
prefill_tokens_per_second = (
last_metrics.prefill_tokens_per_second if last_metrics else None
)
decode_tokens_per_second = (
last_metrics.decode_tokens_per_second if last_metrics else None
)
current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None
@@ -116,6 +129,8 @@ class ModelInferenceMetrics:
completion_tokens=completion_tokens,
total_tokens=total_tokens,
speed_per_second=speed_per_second,
prefill_tokens_per_second=prefill_tokens_per_second,
decode_tokens_per_second=decode_tokens_per_second,
current_gpu_infos=current_gpu_infos,
avg_gpu_infos=avg_gpu_infos,
)
@@ -124,6 +139,65 @@ class ModelInferenceMetrics:
"""Convert the model inference metrics to dict."""
return asdict(self)
def to_printable_string(self) -> str:
"""Stringify the metrics in an elegant format.
Returns:
str: A formatted string containing first token latency, prefill speed,
decode speed, prompt tokens and completion tokens.
"""
lines = []
# Calculate first token latency if possible
first_token_latency = None
if self.first_token_time_ms is not None and self.start_time_ms is not None:
first_token_latency = (
self.first_token_time_ms - self.start_time_ms
) / 1000.0
# Add section header
lines.append("=== Model Inference Metrics ===")
# Latency metrics
lines.append("\n▶ Latency:")
if first_token_latency is not None:
lines.append(f" • First Token Latency: {first_token_latency:.3f}s")
else:
lines.append(" • First Token Latency: N/A")
# Speed metrics
lines.append("\n▶ Speed:")
if self.prefill_tokens_per_second is not None:
lines.append(
f" • Prefill Speed: {self.prefill_tokens_per_second:.2f} tokens/s"
)
else:
lines.append(" • Prefill Speed: N/A")
if self.decode_tokens_per_second is not None:
lines.append(
f" • Decode Speed: {self.decode_tokens_per_second:.2f} tokens/s"
)
else:
lines.append(" • Decode Speed: N/A")
# Token counts
lines.append("\n▶ Tokens:")
if self.prompt_tokens is not None:
lines.append(f" • Prompt Tokens: {self.prompt_tokens}")
else:
lines.append(" • Prompt Tokens: N/A")
if self.completion_tokens is not None:
lines.append(f" • Completion Tokens: {self.completion_tokens}")
else:
lines.append(" • Completion Tokens: N/A")
if self.total_tokens is not None:
lines.append(f" • Total Tokens: {self.total_tokens}")
return "\n".join(lines)
@dataclass
@PublicAPI(stability="beta")
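For reference, a minimal sketch of how the new fields show up in to_printable_string(); the values are invented, constructing the dataclass with no arguments assumes all fields have defaults, and the import path is an assumption (adjust it if the class is exported elsewhere):

# Illustrative only: invented values, assumed no-arg constructor and import path.
from dbgpt.core import ModelInferenceMetrics

m = ModelInferenceMetrics()
m.start_time_ms = 1_000
m.first_token_time_ms = 1_250        # 0.25 s to the first token
m.prompt_tokens = 32
m.completion_tokens = 128
m.total_tokens = 160
m.prefill_tokens_per_second = 128.0
m.decode_tokens_per_second = 42.5

print(m.to_printable_string())
# === Model Inference Metrics ===
#
# ▶ Latency:
#   • First Token Latency: 0.250s
#
# ▶ Speed:
#   • Prefill Speed: 128.00 tokens/s
#   • Decode Speed: 42.50 tokens/s
#
# ▶ Tokens:
#   • Prompt Tokens: 32
#   • Completion Tokens: 128
#   • Total Tokens: 160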

View File

@@ -223,6 +223,12 @@ class LLMModelAdapter(ABC):
"""Load model and tokenizer"""
raise NotImplementedError
def model_patch(
self, deploy_model_params: LLMDeployModelParameters
) -> Optional[Callable[[Any], Any]]:
"""Patch function for model"""
return None
def parse_max_length(self, model, tokenizer) -> Optional[int]:
"""Parse the max_length of the model.

View File

@@ -617,6 +617,13 @@ class Qwen3Adapter(QwenAdapter):
" transformers package."
)
def model_patch(self, deploy_model_params: LLMDeployModelParameters):
"""Apply the monkey patch to moe model for high inference speed."""
from ..llm.monkey_patch import apply_qwen3_moe_monkey_patch
return apply_qwen3_moe_monkey_patch
def is_reasoning_model(
self,
deploy_model_params: LLMDeployModelParameters,

View File

@@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
import logging
from typing import Any, Dict, Optional, cast
from typing import Any, Callable, Dict, Optional, cast
from dbgpt.core.interface.parameter import LLMDeployModelParameters
from dbgpt.model.adapter.base import LLMModelAdapter
@@ -146,8 +146,16 @@ def huggingface_loader(
if model_params.attn_implementation:
kwargs["attn_implementation"] = model_params.attn_implementation
model_patch = llm_adapter.model_patch(model_params)
model, tokenizer = _hf_try_load_default_quantization_model(
model_path, llm_adapter, device, num_gpus, model_params, kwargs
model_path,
llm_adapter,
device,
num_gpus,
model_params,
kwargs,
model_patch=model_patch,
)
if model:
return model, tokenizer
@@ -176,7 +184,7 @@ def huggingface_loader(
compress_module(model, device)
return _hf_handle_model_and_tokenizer(
model, tokenizer, device, num_gpus, model_params
model, tokenizer, device, num_gpus, model_params, model_patch=model_patch
)
@@ -187,6 +195,7 @@ def _hf_try_load_default_quantization_model(
num_gpus: int,
model_params: HFLLMDeployModelParameters,
kwargs: Dict[str, Any],
model_patch: Optional[Callable[[Any], Any]] = None,
):
"""Try load default quantization model(Support by huggingface default)"""
cloned_kwargs = {k: v for k, v in kwargs.items()}
@@ -216,7 +225,13 @@ def _hf_try_load_default_quantization_model(
if model:
logger.info(f"Load default quantization model {model_name} success")
return _hf_handle_model_and_tokenizer(
model, tokenizer, device, num_gpus, model_params, to=False
model,
tokenizer,
device,
num_gpus,
model_params,
to=False,
model_patch=model_patch,
)
return None, None
except Exception as e:
@@ -233,6 +248,7 @@ def _hf_handle_model_and_tokenizer(
num_gpus: int,
model_params: HFLLMDeployModelParameters,
to: bool = True,
model_patch: Optional[Callable[[Any], Any]] = None,
):
if (device == "cuda" and num_gpus == 1) or device == "mps" and tokenizer:
# TODO: Check cpu_offloading
@@ -243,6 +259,11 @@ def _hf_handle_model_and_tokenizer(
pass
except AttributeError:
pass
try:
if model_patch:
model = model_patch(model)
except Exception:
pass
if model_params.verbose:
print(model)
return model, tokenizer
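Condensed, the loading flow after this change looks roughly like the paraphrased sketch below (the function and `load_fn` are invented stand-ins, not the literal loader bodies):

def _load_with_optional_patch(llm_adapter, model_params, load_fn):
    """Paraphrased flow: ask the adapter for a patch, load, then apply it.

    `load_fn` stands in for the Hugging Face loading path; any exception raised
    by the patch is swallowed so loading never fails because of it, mirroring
    _hf_handle_model_and_tokenizer above.
    """
    model_patch = llm_adapter.model_patch(model_params)
    model, tokenizer = load_fn()
    try:
        if model_patch:
            model = model_patch(model)
    except Exception:
        pass  # patching is best-effort
    return model, tokenizer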

View File

@@ -197,7 +197,8 @@ class DefaultModelWorker(ModelWorker):
yield model_output
logger.info(
f"\n\nfull stream output:\n{previous_response}\n\nmodel "
f"generate_stream params:\n{params}"
f"generate_stream params:\n{params}\n"
f"{last_metrics.to_printable_string()}"
)
model_span.end(metadata={"output": previous_response})
span.end()
@@ -238,6 +239,10 @@ class DefaultModelWorker(ModelWorker):
last_metrics,
is_first_generate,
)
last_metrics = current_metrics
logger.info(
f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
)
return model_output
else:
for out in self.generate_stream(params):
@@ -320,7 +325,8 @@ class DefaultModelWorker(ModelWorker):
yield model_output
logger.info(
f"\n\nfull stream output:\n{previous_response}\n\nmodel "
f"generate_stream params:\n{params}"
f"generate_stream params:\n{params}\n"
f"{last_metrics.to_printable_string()}"
)
model_span.end(metadata={"output": previous_response})
span.end()
@@ -359,6 +365,10 @@ class DefaultModelWorker(ModelWorker):
last_metrics,
is_first_generate,
)
last_metrics = current_metrics
logger.info(
f"generate params:\n{params}\n{last_metrics.to_printable_string()}"
)
return model_output
else:
output = None
@@ -576,6 +586,17 @@ def _new_metrics_from_model_output(
# time cost(seconds)
duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0
metrics.speed_per_second = total_tokens / duration
if total_tokens and "prefill_tokens_per_second" in usage:
metrics.prefill_tokens_per_second = usage["prefill_tokens_per_second"]
if total_tokens and "decode_tokens_per_second" in usage:
# Decode speed
metrics.decode_tokens_per_second = usage["decode_tokens_per_second"]
elif total_tokens and metrics.first_token_time_ms:
# time cost(seconds)
duration = (metrics.current_time_ms - metrics.first_token_time_ms) / 1000.0
if duration > 0:
# Calculate decode speed if not provided
metrics.decode_tokens_per_second = metrics.completion_tokens / duration
current_gpu_infos = _get_current_cuda_memory()
metrics.current_gpu_infos = current_gpu_infos
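A small worked example of the fallback decode-speed branch above, with invented numbers:

# Invented numbers, mirroring the fallback branch when usage has no
# "decode_tokens_per_second" key.
completion_tokens = 120
first_token_time_ms = 1_000
current_time_ms = 5_000

duration = (current_time_ms - first_token_time_ms) / 1000.0  # 4.0 s of decoding
decode_tokens_per_second = completion_tokens / duration       # 120 / 4.0 = 30.0 tokens/s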

View File

@@ -2,10 +2,11 @@ import logging
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers import AutoModelForCausalLM, AutoTokenizer
from dbgpt.core import ModelOutput
from ...utils.hf_stream_utils import PerformanceMonitoringStreamer
from ...utils.parse_utils import (
_DEFAULT_THINK_END_TOKEN,
_DEFAULT_THINK_START_TOKEN,
@@ -28,13 +29,15 @@ def huggingface_chat_generate_stream(
temperature = float(params.get("temperature", 0.7))
top_p = float(params.get("top_p", 1.0))
echo = params.get("echo", False)
# max_new_tokens = int(params.get("max_new_tokens", 2048))
max_new_tokens = int(params.get("max_new_tokens", 4096))
stop_token_ids = params.get("stop_token_ids", [])
do_sample = params.get("do_sample", True)
custom_stop_words = params.get("custom_stop_words", [])
think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN)
think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN)
is_reasoning_model = params.get("is_reasoning_model", False)
use_cache = params.get("use_cache", True)
cache_implementation = params.get("cache_implementation")
reasoning_patterns = [
{"start": think_start_token, "end": think_end_token},
]
@@ -52,7 +55,9 @@ def huggingface_chat_generate_stream(
token_kwargs["videos"] = videos
if has_media:
token_kwargs["padding"] = True
tokenize_results = tokenizer(**token_kwargs).to(device)
tokenize_results = tokenizer(**token_kwargs)
input_token_count = tokenize_results.input_ids.shape[1] # Count input tokens
tokenize_results = tokenize_results.to(device)
#
# if model.config.is_encoder_decoder:
# max_src_len = context_len
@@ -63,29 +68,50 @@ def huggingface_chat_generate_stream(
# # input_echo_len = len(input_ids)
# input_ids = torch.as_tensor([input_ids], device=device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=not echo, skip_special_tokens=True
streamer = PerformanceMonitoringStreamer(
tokenizer,
skip_prompt=not echo,
skip_special_tokens=True,
input_token_count=input_token_count,
)
base_kwargs = {
"max_length": context_len,
"temperature": temperature,
"streamer": streamer,
"top_p": top_p,
"use_cache": use_cache,
"max_new_tokens": max_new_tokens,
}
if stop_token_ids:
base_kwargs["eos_token_id"] = stop_token_ids
if do_sample is not None:
base_kwargs["do_sample"] = do_sample
if cache_implementation:
base_kwargs["cache_implementation"] = cache_implementation
logger.info(
f"Predict with parameters: {base_kwargs}\ncustom_stop_words: "
f"{custom_stop_words}"
)
generate_kwargs = {**tokenize_results, **base_kwargs}
thread = Thread(target=model.generate, kwargs=generate_kwargs)
def generate_with_resilience():
try:
_outputs = model.generate(**generate_kwargs, return_dict_in_generate=True)
except torch.cuda.OutOfMemoryError as e:
logger.warning(
f"OOM error occurred: {e}. Trying cleanup and retrying generation."
)
torch.cuda.empty_cache()
model.generate(**generate_kwargs)
except Exception as ex:
logger.error(f"Unexpected error during generation: {ex}")
streamer.end()
raise
streamer.start_prefill()
thread = Thread(target=generate_with_resilience)
thread.start()
text = ""
usage = None
@@ -111,6 +137,15 @@ def huggingface_chat_generate_stream(
extract_reasoning=is_reasoning_model,
reasoning_patterns=reasoning_patterns,
)
perf_metrics = streamer.get_performance_metrics()
usage = {
"prompt_tokens": perf_metrics["input_token_count"],
"completion_tokens": perf_metrics["total_tokens_generated"],
"total_tokens": perf_metrics["input_token_count"]
+ perf_metrics["total_tokens_generated"],
}
usage.update(perf_metrics)
yield ModelOutput.build(
msg.content,
msg.reasoning_content,
@@ -118,3 +153,4 @@ def huggingface_chat_generate_stream(
usage=usage,
is_reasoning_model=is_reasoning_model,
)
thread.join()
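Illustratively, the usage dict attached to the final ModelOutput now carries the standard token counts plus the streamer's performance metrics; a representative (made-up) payload, with keys taken from get_performance_metrics() / LLMPerformanceMetrics.to_dict():

# Representative values only; the key set comes from the code above.
usage = {
    "prompt_tokens": 32,
    "completion_tokens": 128,
    "total_tokens": 160,
    "input_token_count": 32,
    "total_tokens_generated": 128,
    "prefill_time": 0.21,
    "prefill_tokens_per_second": 152.4,
    "total_time": 3.8,
    "avg_decode_time": 0.028,
    "decode_tokens_per_second": 35.7,
    "end_to_end_tokens_per_second": 42.1,
}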

View File

@@ -7,6 +7,7 @@ from vllm.utils import random_uuid
from dbgpt.core import ModelOutput
from ...utils.llm_metrics import LLMPerformanceMonitor
from ...utils.parse_utils import (
_DEFAULT_THINK_END_TOKEN,
_DEFAULT_THINK_START_TOKEN,
@@ -85,6 +86,13 @@ async def generate_stream(
)
# vocab = tokenizer.get_vocab()
# Initialize the performance monitor with estimated token count
estimated_input_tokens = len(tokenizer.encode(prompt))
perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens)
# Start measuring prefill phase
perf_monitor.start_prefill()
results_generator = model.generate(prompt, sampling_params, request_id)
usage = None
finish_reason = None
@@ -101,16 +109,30 @@ async def generate_stream(
completion_tokens = sum(
len(output.token_ids) for output in request_output.outputs
)
# If this is the first iteration, update the input token count
if perf_monitor.metrics.input_token_count != prompt_tokens:
perf_monitor.metrics.input_token_count = prompt_tokens
# Update performance metrics based on current token count
perf_metrics = perf_monitor.on_tokens_received(completion_tokens)
usage = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
# Add performance metrics to usage
usage.update(perf_metrics)
finish_reason = (
request_output.outputs[0].finish_reason
if len(request_output.outputs) == 1
else [output.finish_reason for output in request_output.outputs]
)
# Check if generation is complete
is_complete = finish_reason is not None
if is_complete:
perf_monitor.end_generation()
if text_outputs:
# Tempora
if prompt.rstrip().endswith(think_start_token) and is_reasoning_model:

View File

@@ -5,6 +5,7 @@ import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from torch import nn
@@ -119,3 +120,96 @@ def forward(
def replace_llama_attn_with_non_inplace_operations():
"""Avoid bugs in mps backend by not using in-place operations."""
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
class ParQwen3MoeSparseMoeBlock(nn.Module):
"""
Adapted from https://huggingface.co/Qwen/Qwen3-30B-A3B/discussions/13
"""
def __init__(self, base_moe):
super().__init__()
self.base_moe = base_moe
self.num_experts = base_moe.num_experts
self.top_k = base_moe.top_k
self.norm_topk_prob = base_moe.norm_topk_prob
# gating
self.gate = base_moe.gate
self.experts = base_moe.experts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(
routing_weights, self.top_k, dim=-1
)
if self.norm_topk_prob: # only diff with mixtral sparse moe block!
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be solicited
expert_mask = torch.nn.functional.one_hot(
selected_experts, num_classes=self.num_experts
).permute(2, 1, 0)
activated_experts = torch.unique(selected_experts)
cuda_streams = [torch.cuda.Stream() for _ in activated_experts]
# Loop over all available experts in the model and perform the computation on
# each expert
for expert_idx, cuda_stream in zip(activated_experts, cuda_streams):
with torch.cuda.stream(cuda_stream):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state
# for the current expert. We need to make sure to multiply the output
# hidden states by `routing_weights` on the corresponding tokens
# (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
current_hidden_states = (
expert_layer(current_state) * routing_weights[top_x, idx, None]
)
# However `index_add_` only support torch tensors for indexing so
# we'll use the `top_x` tensor here.
final_hidden_states.index_add_(
0, top_x, current_hidden_states.to(hidden_states.dtype)
)
torch.cuda.synchronize()
final_hidden_states = final_hidden_states.reshape(
batch_size, sequence_length, hidden_dim
)
return final_hidden_states, router_logits
def apply_qwen3_moe_monkey_patch(model):
if not torch.cuda.is_available():
# Only apply monkey patch if CUDA is available
return model
for layer in model.model.layers:
if type(layer.mlp).__name__ == "Qwen3MoeSparseMoeBlock":
layer.mlp = ParQwen3MoeSparseMoeBlock(layer.mlp)
return model
def recovery_moe_monkey_patch(model):
for layer in model.model.layers:
if type(layer.mlp).__name__ == "ParQwen3MoeSparseMoeBlock" and hasattr(
layer.mlp, "base_moe"
):
layer.mlp = layer.mlp.base_moe
return model
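A minimal sketch of applying and reverting the patch on a locally loaded Qwen3 MoE checkpoint; the module path is inferred from the relative import in the adapter above, the checkpoint name is just an example, and the patch is a no-op without CUDA:

# Sketch only: assumes the module resolves to dbgpt.model.llm.monkey_patch and
# that the checkpoint is available locally.
import torch
from transformers import AutoModelForCausalLM

from dbgpt.model.llm.monkey_patch import (
    apply_qwen3_moe_monkey_patch,
    recovery_moe_monkey_patch,
)

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B", torch_dtype=torch.bfloat16, device_map="auto"
)
model = apply_qwen3_moe_monkey_patch(model)   # replaces each Qwen3MoeSparseMoeBlock
# ... run generation ...
model = recovery_moe_monkey_patch(model)      # restore the original MoE blocks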

View File

@@ -0,0 +1,70 @@
import logging
from transformers import TextIteratorStreamer
from .llm_metrics import LLMPerformanceMonitor
logger = logging.getLogger(__name__)
class PerformanceMonitoringStreamer(TextIteratorStreamer):
"""
Extended TextIteratorStreamer that monitors LLM inference performance.
Uses the generic LLMPerformanceMonitor for performance tracking.
"""
def __init__(
self,
tokenizer,
skip_prompt=False,
timeout=None,
input_token_count=0,
**decode_kwargs,
):
super().__init__(
tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs
)
# Initialize the performance monitor
self.perf_monitor = LLMPerformanceMonitor(input_token_count=input_token_count)
# Additional flags for streamer-specific behavior
self.is_prompt_token = True # Flag to track if current tokens are from prompt
def start_prefill(self):
"""Mark the beginning of the prefill phase"""
self.perf_monitor.start_prefill()
def put(self, value):
"""
Receive tokens and track performance metrics.
Automatically detects prefill/decode phase transitions.
"""
# Skip counting if these are prompt tokens and skip_prompt is True
if self.skip_prompt and self.is_prompt_token:
self.is_prompt_token = False # Mark that we've processed the prompt tokens
logger.debug("Skipping prompt tokens for performance measurement")
super().put(value) # Call parent method to continue flow
return
# Calculate number of new tokens
token_count = len(value.tolist())
total_token_count = self.perf_monitor.metrics.prev_tokens_count + token_count
# Update performance metrics
self.perf_monitor.on_tokens_received(total_token_count)
# Call the parent method to continue the original flow
super().put(value)
def end(self):
"""End generation and finalize performance metrics"""
# Finalize metrics
self.perf_monitor.end_generation()
# Call the parent method to continue the original flow
super().end()
def get_performance_metrics(self):
"""Get performance metrics in a format suitable for API responses"""
return self.perf_monitor.get_metrics_dict()
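A compact usage sketch, mirroring how the Hugging Face chat path above wires the streamer into model.generate; the import path is inferred from the relative import in that file, and the small Qwen3 checkpoint is only an example:

from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Path inferred from the relative import above; adjust if the module lives elsewhere.
from dbgpt.model.utils.hf_stream_utils import PerformanceMonitoringStreamer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

inputs = tokenizer("Hello, world", return_tensors="pt").to(model.device)

streamer = PerformanceMonitoringStreamer(
    tokenizer,
    skip_prompt=True,
    skip_special_tokens=True,
    input_token_count=inputs.input_ids.shape[1],
)

streamer.start_prefill()
thread = Thread(
    target=model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 64},
)
thread.start()

text = "".join(piece for piece in streamer)  # behaves like a TextIteratorStreamer
thread.join()
print(streamer.get_performance_metrics())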

View File

@@ -0,0 +1,194 @@
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class LLMPerformanceMetrics:
"""Performance metrics for LLM inference, including prefill and decode phases"""
# Token counts
input_token_count: int = 0
total_tokens_generated: int = 0
prev_tokens_count: int = 0
# Time measurements in nanoseconds
start_time_ns: int = field(default_factory=time.time_ns)
prefill_start_time_ns: Optional[int] = None
prefill_end_time_ns: Optional[int] = None
prefill_time_ns: Optional[int] = None
decode_start_time_ns: Optional[int] = None
total_time_ns: Optional[int] = None
# Timestamps and measurements
token_timestamps_ns: List[int] = field(default_factory=list)
decode_times_ns: List[int] = field(default_factory=list)
# Calculated metrics (tokens per second)
prefill_tokens_per_second: Optional[float] = None
decode_tokens_per_second: Optional[float] = None
end_to_end_tokens_per_second: Optional[float] = None
# Additional computed values
avg_decode_time: Optional[float] = None
def to_dict(self) -> Dict[str, any]:
"""Convert metrics to a dictionary for API response, with times in seconds"""
metrics = {
"input_token_count": self.input_token_count,
"total_tokens_generated": self.total_tokens_generated,
}
# Add time metrics in seconds
if self.prefill_time_ns is not None:
metrics["prefill_time"] = self.prefill_time_ns / 1e9
metrics["prefill_tokens_per_second"] = self.prefill_tokens_per_second or 0
if self.total_time_ns is not None:
metrics["total_time"] = self.total_time_ns / 1e9
if self.decode_times_ns:
metrics["avg_decode_time"] = self.avg_decode_time
metrics["decode_tokens_per_second"] = self.decode_tokens_per_second or 0
# Add throughput metrics
metrics["end_to_end_tokens_per_second"] = self.end_to_end_tokens_per_second or 0
return metrics
class LLMPerformanceMonitor:
"""Generic performance monitor for LLM inference that tracks prefill and decode
phases"""
def __init__(self, input_token_count: int = 0):
# Performance metrics
self.metrics = LLMPerformanceMetrics(input_token_count=input_token_count)
# Phase flags
self.prefill_started: bool = False
self.first_token_received: bool = False
def start_prefill(self) -> int:
"""Mark the beginning of the prefill phase using nanosecond timestamp"""
timestamp = time.time_ns()
self.metrics.prefill_start_time_ns = timestamp
self.prefill_started = True
return timestamp
def on_tokens_received(self, current_token_count: int) -> Dict[str, any]:
"""
Called when new tokens are received from LLM
Returns updated performance metrics
"""
current_time_ns = time.time_ns()
# Calculate new tokens received in this batch
new_tokens = current_token_count - self.metrics.prev_tokens_count
# Auto-detect the end of prefill / start of decode phase
if self.prefill_started and not self.first_token_received and new_tokens > 0:
# This is the first tokens batch - mark the end of prefill phase
self.metrics.prefill_end_time_ns = current_time_ns
self.metrics.prefill_time_ns = (
self.metrics.prefill_end_time_ns - self.metrics.prefill_start_time_ns
)
# Convert nanoseconds to seconds for calculation and logging
prefill_time_sec = self.metrics.prefill_time_ns / 1e9
# Calculate prefill speed
if self.metrics.input_token_count > 0 and prefill_time_sec > 0:
self.metrics.prefill_tokens_per_second = (
self.metrics.input_token_count / prefill_time_sec
)
logger.info(
f"Prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
f"tokens/s for {self.metrics.input_token_count} tokens"
)
# Mark the beginning of decode phase
self.metrics.decode_start_time_ns = current_time_ns
self.first_token_received = True
# Record token generation data
if self.first_token_received and new_tokens > 0:
# If we've already received tokens, add decode time for this batch
if len(self.metrics.token_timestamps_ns) > 0:
last_timestamp_ns = self.metrics.token_timestamps_ns[-1]
token_decode_time_ns = current_time_ns - last_timestamp_ns
# Distribute the time evenly across all new tokens in this batch
time_per_token = token_decode_time_ns / new_tokens
for _ in range(new_tokens):
self.metrics.decode_times_ns.append(time_per_token)
# Record the current token batch timestamp
self.metrics.token_timestamps_ns.append(current_time_ns)
self.metrics.total_tokens_generated += new_tokens
self.metrics.prev_tokens_count = current_token_count
# Calculate current metrics
self._update_metrics(current_time_ns)
return self.get_metrics_dict()
def _update_metrics(self, current_time_ns: Optional[int] = None) -> None:
"""Update the performance metrics based on current state"""
if current_time_ns is None:
current_time_ns = time.time_ns()
# Record total time
self.metrics.total_time_ns = current_time_ns - self.metrics.start_time_ns
# Calculate average decode speed
if self.metrics.decode_times_ns:
# Convert to seconds
decode_times_sec = [t / 1e9 for t in self.metrics.decode_times_ns]
self.metrics.avg_decode_time = sum(decode_times_sec) / len(decode_times_sec)
self.metrics.decode_tokens_per_second = 1.0 / self.metrics.avg_decode_time
# Calculate end-to-end throughput
total_time_sec = self.metrics.total_time_ns / 1e9
if total_time_sec > 0:
total_tokens = (
self.metrics.input_token_count + self.metrics.total_tokens_generated
)
self.metrics.end_to_end_tokens_per_second = total_tokens / total_time_sec
def end_generation(self) -> Dict[str, any]:
"""Mark the end of generation and finalize metrics"""
current_time_ns = time.time_ns()
self._update_metrics(current_time_ns)
# Log final performance data
total_time_sec = self.metrics.total_time_ns / 1e9
logger.info(f"Generation complete. Total time: {total_time_sec:.6f}s")
if self.metrics.prefill_tokens_per_second:
logger.info(
f"Final prefill speed: {self.metrics.prefill_tokens_per_second:.2f} "
"tokens/s"
)
if self.metrics.decode_tokens_per_second:
logger.info(
f"Final decode speed: {self.metrics.decode_tokens_per_second:.2f} "
"tokens/s"
)
if self.metrics.end_to_end_tokens_per_second:
logger.info(
"End-to-end throughput: "
f"{self.metrics.end_to_end_tokens_per_second:.2f} tokens/s"
)
return self.get_metrics_dict()
def get_metrics_dict(self) -> Dict[str, any]:
"""Get performance metrics as dictionary, converting nanoseconds to seconds
for external use"""
return self.metrics.to_dict()
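Outside of the streamer, the monitor can be driven directly from any generation loop that knows its cumulative token count; a minimal sketch with a simulated loop (import path inferred from the relative imports above):

import time

# Path inferred from the relative imports above; adjust if the module lives elsewhere.
from dbgpt.model.utils.llm_metrics import LLMPerformanceMonitor

monitor = LLMPerformanceMonitor(input_token_count=32)
monitor.start_prefill()

generated = 0
for _ in range(5):
    time.sleep(0.05)                        # stand-in for real decode work
    generated += 4                          # pretend 4 new tokens arrived in this batch
    monitor.on_tokens_received(generated)   # pass the *cumulative* token count

final_metrics = monitor.end_generation()
print(final_metrics["total_tokens_generated"], final_metrics["decode_tokens_per_second"])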