diff --git a/packages/dbgpt-core/src/dbgpt/core/interface/llm.py b/packages/dbgpt-core/src/dbgpt/core/interface/llm.py
index ea0f30072..efcb340b0 100644
--- a/packages/dbgpt-core/src/dbgpt/core/interface/llm.py
+++ b/packages/dbgpt-core/src/dbgpt/core/interface/llm.py
@@ -65,7 +65,14 @@ class ModelInferenceMetrics:
     """The total number of tokens (prompt plus completion)."""
 
     speed_per_second: Optional[float] = None
-    """The average number of tokens generated per second."""
+    """The average number of tokens generated per second, measured over the total
+    inference time (prefill plus decode)."""
+
+    prefill_tokens_per_second: Optional[float] = None
+    """Prefill speed in tokens per second."""
+
+    decode_tokens_per_second: Optional[float] = None
+    """The average number of tokens generated per second during the decode phase."""
 
     current_gpu_infos: Optional[List[GPUInfo]] = None
     """Current gpu information, all devices"""
@@ -97,6 +104,12 @@ class ModelInferenceMetrics:
         completion_tokens = last_metrics.completion_tokens if last_metrics else None
         total_tokens = last_metrics.total_tokens if last_metrics else None
         speed_per_second = last_metrics.speed_per_second if last_metrics else None
+        prefill_tokens_per_second = (
+            last_metrics.prefill_tokens_per_second if last_metrics else None
+        )
+        decode_tokens_per_second = (
+            last_metrics.decode_tokens_per_second if last_metrics else None
+        )
         current_gpu_infos = last_metrics.current_gpu_infos if last_metrics else None
         avg_gpu_infos = last_metrics.avg_gpu_infos if last_metrics else None
 
@@ -116,6 +129,8 @@ class ModelInferenceMetrics:
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             speed_per_second=speed_per_second,
+            prefill_tokens_per_second=prefill_tokens_per_second,
+            decode_tokens_per_second=decode_tokens_per_second,
             current_gpu_infos=current_gpu_infos,
             avg_gpu_infos=avg_gpu_infos,
         )
@@ -124,6 +139,65 @@ class ModelInferenceMetrics:
         """Convert the model inference metrics to dict."""
         return asdict(self)
 
+    def to_printable_string(self) -> str:
+        """Format the metrics as a human-readable string.
+
+        Returns:
+            str: A formatted string containing the first token latency, prefill and
+                decode speeds, and the prompt/completion token counts.
+ """ + lines = [] + + # Calculate first token latency if possible + first_token_latency = None + if self.first_token_time_ms is not None and self.start_time_ms is not None: + first_token_latency = ( + self.first_token_time_ms - self.start_time_ms + ) / 1000.0 + + # Add section header + lines.append("=== Model Inference Metrics ===") + + # Latency metrics + lines.append("\n▶ Latency:") + if first_token_latency is not None: + lines.append(f" • First Token Latency: {first_token_latency:.3f}s") + else: + lines.append(" • First Token Latency: N/A") + + # Speed metrics + lines.append("\n▶ Speed:") + if self.prefill_tokens_per_second is not None: + lines.append( + f" • Prefill Speed: {self.prefill_tokens_per_second:.2f} tokens/s" + ) + else: + lines.append(" • Prefill Speed: N/A") + + if self.decode_tokens_per_second is not None: + lines.append( + f" • Decode Speed: {self.decode_tokens_per_second:.2f} tokens/s" + ) + else: + lines.append(" • Decode Speed: N/A") + + # Token counts + lines.append("\n▶ Tokens:") + if self.prompt_tokens is not None: + lines.append(f" • Prompt Tokens: {self.prompt_tokens}") + else: + lines.append(" • Prompt Tokens: N/A") + + if self.completion_tokens is not None: + lines.append(f" • Completion Tokens: {self.completion_tokens}") + else: + lines.append(" • Completion Tokens: N/A") + + if self.total_tokens is not None: + lines.append(f" • Total Tokens: {self.total_tokens}") + + return "\n".join(lines) + @dataclass @PublicAPI(stability="beta") diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/base.py b/packages/dbgpt-core/src/dbgpt/model/adapter/base.py index ffb75fee7..e93083840 100644 --- a/packages/dbgpt-core/src/dbgpt/model/adapter/base.py +++ b/packages/dbgpt-core/src/dbgpt/model/adapter/base.py @@ -223,6 +223,12 @@ class LLMModelAdapter(ABC): """Load model and tokenizer""" raise NotImplementedError + def model_patch( + self, deploy_model_params: LLMDeployModelParameters + ) -> Optional[Callable[[Any], Any]]: + """Patch function for model""" + return None + def parse_max_length(self, model, tokenizer) -> Optional[int]: """Parse the max_length of the model. diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/hf_adapter.py b/packages/dbgpt-core/src/dbgpt/model/adapter/hf_adapter.py index 83e8aeaa0..8af00b1db 100644 --- a/packages/dbgpt-core/src/dbgpt/model/adapter/hf_adapter.py +++ b/packages/dbgpt-core/src/dbgpt/model/adapter/hf_adapter.py @@ -617,6 +617,13 @@ class Qwen3Adapter(QwenAdapter): " transformers package." 
) + def model_patch(self, deploy_model_params: LLMDeployModelParameters): + """Apply the monkey patch to moe model for high inference speed.""" + + from ..llm.monkey_patch import apply_qwen3_moe_monkey_patch + + return apply_qwen3_moe_monkey_patch + def is_reasoning_model( self, deploy_model_params: LLMDeployModelParameters, diff --git a/packages/dbgpt-core/src/dbgpt/model/adapter/loader.py b/packages/dbgpt-core/src/dbgpt/model/adapter/loader.py index 6165144af..c15c9eb3c 100644 --- a/packages/dbgpt-core/src/dbgpt/model/adapter/loader.py +++ b/packages/dbgpt-core/src/dbgpt/model/adapter/loader.py @@ -2,7 +2,7 @@ # -*- coding: utf-8 -*- import logging -from typing import Any, Dict, Optional, cast +from typing import Any, Callable, Dict, Optional, cast from dbgpt.core.interface.parameter import LLMDeployModelParameters from dbgpt.model.adapter.base import LLMModelAdapter @@ -146,8 +146,16 @@ def huggingface_loader( if model_params.attn_implementation: kwargs["attn_implementation"] = model_params.attn_implementation + model_patch = llm_adapter.model_patch(model_params) + model, tokenizer = _hf_try_load_default_quantization_model( - model_path, llm_adapter, device, num_gpus, model_params, kwargs + model_path, + llm_adapter, + device, + num_gpus, + model_params, + kwargs, + model_patch=model_patch, ) if model: return model, tokenizer @@ -176,7 +184,7 @@ def huggingface_loader( compress_module(model, device) return _hf_handle_model_and_tokenizer( - model, tokenizer, device, num_gpus, model_params + model, tokenizer, device, num_gpus, model_params, model_patch=model_patch ) @@ -187,6 +195,7 @@ def _hf_try_load_default_quantization_model( num_gpus: int, model_params: HFLLMDeployModelParameters, kwargs: Dict[str, Any], + model_patch: Optional[Callable[[Any], Any]] = None, ): """Try load default quantization model(Support by huggingface default)""" cloned_kwargs = {k: v for k, v in kwargs.items()} @@ -216,7 +225,13 @@ def _hf_try_load_default_quantization_model( if model: logger.info(f"Load default quantization model {model_name} success") return _hf_handle_model_and_tokenizer( - model, tokenizer, device, num_gpus, model_params, to=False + model, + tokenizer, + device, + num_gpus, + model_params, + to=False, + model_patch=model_patch, ) return None, None except Exception as e: @@ -233,6 +248,7 @@ def _hf_handle_model_and_tokenizer( num_gpus: int, model_params: HFLLMDeployModelParameters, to: bool = True, + model_patch: Optional[Callable[[Any], Any]] = None, ): if (device == "cuda" and num_gpus == 1) or device == "mps" and tokenizer: # TODO: Check cpu_offloading @@ -243,6 +259,11 @@ def _hf_handle_model_and_tokenizer( pass except AttributeError: pass + try: + if model_patch: + model = model_patch(model) + except Exception: + pass if model_params.verbose: print(model) return model, tokenizer diff --git a/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py b/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py index 9136ea6fb..160ca9e4b 100644 --- a/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py +++ b/packages/dbgpt-core/src/dbgpt/model/cluster/worker/default_worker.py @@ -197,7 +197,8 @@ class DefaultModelWorker(ModelWorker): yield model_output logger.info( f"\n\nfull stream output:\n{previous_response}\n\nmodel " - f"generate_stream params:\n{params}" + f"generate_stream params:\n{params}\n" + f"{last_metrics.to_printable_string()}" ) model_span.end(metadata={"output": previous_response}) span.end() @@ -238,6 +239,10 @@ class 
DefaultModelWorker(ModelWorker): last_metrics, is_first_generate, ) + last_metrics = current_metrics + logger.info( + f"generate params:\n{params}\n{last_metrics.to_printable_string()}" + ) return model_output else: for out in self.generate_stream(params): @@ -320,7 +325,8 @@ class DefaultModelWorker(ModelWorker): yield model_output logger.info( f"\n\nfull stream output:\n{previous_response}\n\nmodel " - f"generate_stream params:\n{params}" + f"generate_stream params:\n{params}\n" + f"{last_metrics.to_printable_string()}" ) model_span.end(metadata={"output": previous_response}) span.end() @@ -359,6 +365,10 @@ class DefaultModelWorker(ModelWorker): last_metrics, is_first_generate, ) + last_metrics = current_metrics + logger.info( + f"generate params:\n{params}\n{last_metrics.to_printable_string()}" + ) return model_output else: output = None @@ -576,6 +586,17 @@ def _new_metrics_from_model_output( # time cost(seconds) duration = (metrics.current_time_ms - metrics.start_time_ms) / 1000.0 metrics.speed_per_second = total_tokens / duration + if total_tokens and "prefill_tokens_per_second" in usage: + metrics.prefill_tokens_per_second = usage["prefill_tokens_per_second"] + if total_tokens and "decode_tokens_per_second" in usage: + # Decode speed + metrics.decode_tokens_per_second = usage["decode_tokens_per_second"] + elif total_tokens and metrics.first_token_time_ms: + # time cost(seconds) + duration = (metrics.current_time_ms - metrics.first_token_time_ms) / 1000.0 + if duration > 0: + # Calculate decode speed if not provided + metrics.decode_tokens_per_second = metrics.completion_tokens / duration current_gpu_infos = _get_current_cuda_memory() metrics.current_gpu_infos = current_gpu_infos diff --git a/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/hf_chat_llm.py b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/hf_chat_llm.py index 2d6660b9a..1bd05c322 100644 --- a/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/hf_chat_llm.py +++ b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/hf_chat_llm.py @@ -2,10 +2,11 @@ import logging from threading import Thread import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer +from transformers import AutoModelForCausalLM, AutoTokenizer from dbgpt.core import ModelOutput +from ...utils.hf_stream_utils import PerformanceMonitoringStreamer from ...utils.parse_utils import ( _DEFAULT_THINK_END_TOKEN, _DEFAULT_THINK_START_TOKEN, @@ -28,13 +29,15 @@ def huggingface_chat_generate_stream( temperature = float(params.get("temperature", 0.7)) top_p = float(params.get("top_p", 1.0)) echo = params.get("echo", False) - # max_new_tokens = int(params.get("max_new_tokens", 2048)) + max_new_tokens = int(params.get("max_new_tokens", 4096)) stop_token_ids = params.get("stop_token_ids", []) do_sample = params.get("do_sample", True) custom_stop_words = params.get("custom_stop_words", []) think_start_token = params.get("think_start_token", _DEFAULT_THINK_START_TOKEN) think_end_token = params.get("think_end_token", _DEFAULT_THINK_END_TOKEN) is_reasoning_model = params.get("is_reasoning_model", False) + use_cache = params.get("use_cache", True) + cache_implementation = params.get("cache_implementation") reasoning_patterns = [ {"start": think_start_token, "end": think_end_token}, ] @@ -52,7 +55,9 @@ def huggingface_chat_generate_stream( token_kwargs["videos"] = videos if has_media: token_kwargs["padding"] = True - tokenize_results = tokenizer(**token_kwargs).to(device) + tokenize_results = tokenizer(**token_kwargs) + input_token_count = 
tokenize_results.input_ids.shape[1] # Count input tokens + tokenize_results = tokenize_results.to(device) # # if model.config.is_encoder_decoder: # max_src_len = context_len @@ -63,29 +68,50 @@ def huggingface_chat_generate_stream( # # input_echo_len = len(input_ids) # input_ids = torch.as_tensor([input_ids], device=device) - streamer = TextIteratorStreamer( - tokenizer, skip_prompt=not echo, skip_special_tokens=True + streamer = PerformanceMonitoringStreamer( + tokenizer, + skip_prompt=not echo, + skip_special_tokens=True, + input_token_count=input_token_count, ) base_kwargs = { - "max_length": context_len, "temperature": temperature, "streamer": streamer, "top_p": top_p, + "use_cache": use_cache, + "max_new_tokens": max_new_tokens, } if stop_token_ids: base_kwargs["eos_token_id"] = stop_token_ids if do_sample is not None: base_kwargs["do_sample"] = do_sample + if cache_implementation: + base_kwargs["cache_implementation"] = cache_implementation logger.info( f"Predict with parameters: {base_kwargs}\ncustom_stop_words: " f"{custom_stop_words}" ) - generate_kwargs = {**tokenize_results, **base_kwargs} - thread = Thread(target=model.generate, kwargs=generate_kwargs) + + def generate_with_resilience(): + try: + _outputs = model.generate(**generate_kwargs, return_dict_in_generate=True) + except torch.cuda.OutOfMemoryError as e: + logger.warning( + f"OOM error occurred: {e}. Trying cleanup and retrying generation." + ) + torch.cuda.empty_cache() + model.generate(**generate_kwargs) + except Exception as ex: + logger.error(f"Unexpected error during generation: {ex}") + streamer.end() + raise + + streamer.start_prefill() + thread = Thread(target=generate_with_resilience) thread.start() text = "" usage = None @@ -111,6 +137,15 @@ def huggingface_chat_generate_stream( extract_reasoning=is_reasoning_model, reasoning_patterns=reasoning_patterns, ) + perf_metrics = streamer.get_performance_metrics() + usage = { + "prompt_tokens": perf_metrics["input_token_count"], + "completion_tokens": perf_metrics["total_tokens_generated"], + "total_tokens": perf_metrics["input_token_count"] + + perf_metrics["total_tokens_generated"], + } + usage.update(perf_metrics) + yield ModelOutput.build( msg.content, msg.reasoning_content, @@ -118,3 +153,4 @@ def huggingface_chat_generate_stream( usage=usage, is_reasoning_model=is_reasoning_model, ) + thread.join() diff --git a/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/vllm_llm.py b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/vllm_llm.py index a49a29de0..03709b7b2 100644 --- a/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/vllm_llm.py +++ b/packages/dbgpt-core/src/dbgpt/model/llm/llm_out/vllm_llm.py @@ -7,6 +7,7 @@ from vllm.utils import random_uuid from dbgpt.core import ModelOutput +from ...utils.llm_metrics import LLMPerformanceMonitor from ...utils.parse_utils import ( _DEFAULT_THINK_END_TOKEN, _DEFAULT_THINK_START_TOKEN, @@ -85,6 +86,13 @@ async def generate_stream( ) # vocab = tokenizer.get_vocab() + # Initialize the performance monitor with estimated token count + estimated_input_tokens = len(tokenizer.encode(prompt)) + perf_monitor = LLMPerformanceMonitor(input_token_count=estimated_input_tokens) + + # Start measuring prefill phase + perf_monitor.start_prefill() + results_generator = model.generate(prompt, sampling_params, request_id) usage = None finish_reason = None @@ -101,16 +109,30 @@ async def generate_stream( completion_tokens = sum( len(output.token_ids) for output in request_output.outputs ) + # If this is the first iteration, update the input token 
count + if perf_monitor.metrics.input_token_count != prompt_tokens: + perf_monitor.metrics.input_token_count = prompt_tokens + + # Update performance metrics based on current token count + perf_metrics = perf_monitor.on_tokens_received(completion_tokens) + usage = { "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, } + # Add performance metrics to usage + usage.update(perf_metrics) + finish_reason = ( request_output.outputs[0].finish_reason if len(request_output.outputs) == 1 else [output.finish_reason for output in request_output.outputs] ) + # Check if generation is complete + is_complete = finish_reason is not None + if is_complete: + perf_monitor.end_generation() if text_outputs: # Tempora if prompt.rstrip().endswith(think_start_token) and is_reasoning_model: diff --git a/packages/dbgpt-core/src/dbgpt/model/llm/monkey_patch.py b/packages/dbgpt-core/src/dbgpt/model/llm/monkey_patch.py index 705a1e056..8601440b1 100644 --- a/packages/dbgpt-core/src/dbgpt/model/llm/monkey_patch.py +++ b/packages/dbgpt-core/src/dbgpt/model/llm/monkey_patch.py @@ -5,6 +5,7 @@ import math from typing import Optional, Tuple import torch +import torch.nn.functional as F import transformers from torch import nn @@ -119,3 +120,96 @@ def forward( def replace_llama_attn_with_non_inplace_operations(): """Avoid bugs in mps backend by not using in-place operations.""" transformers.models.llama.modeling_llama.LlamaAttention.forward = forward + + +class ParQwen3MoeSparseMoeBlock(nn.Module): + """ + Adapted from https://huggingface.co/Qwen/Qwen3-30B-A3B/discussions/13 + """ + + def __init__(self, base_moe): + super().__init__() + self.base_moe = base_moe + self.num_experts = base_moe.num_experts + self.top_k = base_moe.top_k + self.norm_topk_prob = base_moe.norm_topk_prob + + # gating + self.gate = base_moe.gate + self.experts = base_moe.experts + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) + if self.norm_topk_prob: # only diff with mixtral sparse moe block! + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot( + selected_experts, num_classes=self.num_experts + ).permute(2, 1, 0) + + activated_experts = torch.unique(selected_experts) + cuda_streams = [torch.cuda.Stream() for _ in activated_experts] + # Loop over all available experts in the model and perform the computation on + # each expert + for expert_idx, cuda_stream in zip(activated_experts, cuda_streams): + with torch.cuda.stream(cuda_stream): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state + # for the current expert. 
We need to make sure to multiply the output + # hidden states by `routing_weights` on the corresponding tokens + # (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = ( + expert_layer(current_state) * routing_weights[top_x, idx, None] + ) + + # However `index_add_` only support torch tensors for indexing so + # we'll use the `top_x` tensor here. + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) + torch.cuda.synchronize() + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, router_logits + + +def apply_qwen3_moe_monkey_patch(model): + if not torch.cuda.is_available(): + # Only apply monkey patch if CUDA is available + return model + + for layer in model.model.layers: + if type(layer.mlp).__name__ == "Qwen3MoeSparseMoeBlock": + layer.mlp = ParQwen3MoeSparseMoeBlock(layer.mlp) + return model + + +def recovery_moe_monkey_patch(model): + for layer in model.model.layers: + if type(layer.mlp).__name__ == "ParQwen3MoeSparseMoeBlock" and hasattr( + layer.mlp, "base_moe" + ): + layer.mlp = layer.mlp.base_moe + return model diff --git a/packages/dbgpt-core/src/dbgpt/model/utils/hf_stream_utils.py b/packages/dbgpt-core/src/dbgpt/model/utils/hf_stream_utils.py new file mode 100644 index 000000000..032cba7d0 --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/model/utils/hf_stream_utils.py @@ -0,0 +1,70 @@ +import logging + +from transformers import TextIteratorStreamer + +from .llm_metrics import LLMPerformanceMonitor + +logger = logging.getLogger(__name__) + + +class PerformanceMonitoringStreamer(TextIteratorStreamer): + """ + Extended TextIteratorStreamer that monitors LLM inference performance. + Uses the generic LLMPerformanceMonitor for performance tracking. + """ + + def __init__( + self, + tokenizer, + skip_prompt=False, + timeout=None, + input_token_count=0, + **decode_kwargs, + ): + super().__init__( + tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs + ) + + # Initialize the performance monitor + self.perf_monitor = LLMPerformanceMonitor(input_token_count=input_token_count) + + # Additional flags for streamer-specific behavior + self.is_prompt_token = True # Flag to track if current tokens are from prompt + + def start_prefill(self): + """Mark the beginning of the prefill phase""" + self.perf_monitor.start_prefill() + + def put(self, value): + """ + Receive tokens and track performance metrics. + Automatically detects prefill/decode phase transitions. 
+ """ + # Skip counting if these are prompt tokens and skip_prompt is True + if self.skip_prompt and self.is_prompt_token: + self.is_prompt_token = False # Mark that we've processed the prompt tokens + logger.debug("Skipping prompt tokens for performance measurement") + super().put(value) # Call parent method to continue flow + return + + # Calculate number of new tokens + token_count = len(value.tolist()) + total_token_count = self.perf_monitor.metrics.prev_tokens_count + token_count + + # Update performance metrics + self.perf_monitor.on_tokens_received(total_token_count) + + # Call the parent method to continue the original flow + super().put(value) + + def end(self): + """End generation and finalize performance metrics""" + # Finalize metrics + self.perf_monitor.end_generation() + + # Call the parent method to continue the original flow + super().end() + + def get_performance_metrics(self): + """Get performance metrics in a format suitable for API responses""" + return self.perf_monitor.get_metrics_dict() diff --git a/packages/dbgpt-core/src/dbgpt/model/utils/llm_metrics.py b/packages/dbgpt-core/src/dbgpt/model/utils/llm_metrics.py new file mode 100644 index 000000000..641c4c53a --- /dev/null +++ b/packages/dbgpt-core/src/dbgpt/model/utils/llm_metrics.py @@ -0,0 +1,194 @@ +import logging +import time +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + + +@dataclass +class LLMPerformanceMetrics: + """Performance metrics for LLM inference, including prefill and decode phases""" + + # Token counts + input_token_count: int = 0 + total_tokens_generated: int = 0 + prev_tokens_count: int = 0 + + # Time measurements in nanoseconds + start_time_ns: int = field(default_factory=time.time_ns) + prefill_start_time_ns: Optional[int] = None + prefill_end_time_ns: Optional[int] = None + prefill_time_ns: Optional[int] = None + decode_start_time_ns: Optional[int] = None + total_time_ns: Optional[int] = None + + # Timestamps and measurements + token_timestamps_ns: List[int] = field(default_factory=list) + decode_times_ns: List[int] = field(default_factory=list) + + # Calculated metrics (tokens per second) + prefill_tokens_per_second: Optional[float] = None + decode_tokens_per_second: Optional[float] = None + end_to_end_tokens_per_second: Optional[float] = None + + # Additional computed values + avg_decode_time: Optional[float] = None + + def to_dict(self) -> Dict[str, any]: + """Convert metrics to a dictionary for API response, with times in seconds""" + metrics = { + "input_token_count": self.input_token_count, + "total_tokens_generated": self.total_tokens_generated, + } + + # Add time metrics in seconds + if self.prefill_time_ns is not None: + metrics["prefill_time"] = self.prefill_time_ns / 1e9 + metrics["prefill_tokens_per_second"] = self.prefill_tokens_per_second or 0 + + if self.total_time_ns is not None: + metrics["total_time"] = self.total_time_ns / 1e9 + + if self.decode_times_ns: + metrics["avg_decode_time"] = self.avg_decode_time + metrics["decode_tokens_per_second"] = self.decode_tokens_per_second or 0 + + # Add throughput metrics + metrics["end_to_end_tokens_per_second"] = self.end_to_end_tokens_per_second or 0 + + return metrics + + +class LLMPerformanceMonitor: + """Generic performance monitor for LLM inference that tracks prefill and decode + phases""" + + def __init__(self, input_token_count: int = 0): + # Performance metrics + self.metrics = LLMPerformanceMetrics(input_token_count=input_token_count) + + # Phase 
flags + self.prefill_started: bool = False + self.first_token_received: bool = False + + def start_prefill(self) -> int: + """Mark the beginning of the prefill phase using nanosecond timestamp""" + timestamp = time.time_ns() + self.metrics.prefill_start_time_ns = timestamp + self.prefill_started = True + return timestamp + + def on_tokens_received(self, current_token_count: int) -> Dict[str, any]: + """ + Called when new tokens are received from LLM + Returns updated performance metrics + """ + current_time_ns = time.time_ns() + + # Calculate new tokens received in this batch + new_tokens = current_token_count - self.metrics.prev_tokens_count + + # Auto-detect the end of prefill / start of decode phase + if self.prefill_started and not self.first_token_received and new_tokens > 0: + # This is the first tokens batch - mark the end of prefill phase + self.metrics.prefill_end_time_ns = current_time_ns + self.metrics.prefill_time_ns = ( + self.metrics.prefill_end_time_ns - self.metrics.prefill_start_time_ns + ) + + # Convert nanoseconds to seconds for calculation and logging + prefill_time_sec = self.metrics.prefill_time_ns / 1e9 + + # Calculate prefill speed + if self.metrics.input_token_count > 0 and prefill_time_sec > 0: + self.metrics.prefill_tokens_per_second = ( + self.metrics.input_token_count / prefill_time_sec + ) + logger.info( + f"Prefill speed: {self.metrics.prefill_tokens_per_second:.2f} " + f"tokens/s for {self.metrics.input_token_count} tokens" + ) + + # Mark the beginning of decode phase + self.metrics.decode_start_time_ns = current_time_ns + self.first_token_received = True + + # Record token generation data + if self.first_token_received and new_tokens > 0: + # If we've already received tokens, add decode time for this batch + if len(self.metrics.token_timestamps_ns) > 0: + last_timestamp_ns = self.metrics.token_timestamps_ns[-1] + token_decode_time_ns = current_time_ns - last_timestamp_ns + + # Distribute the time evenly across all new tokens in this batch + time_per_token = token_decode_time_ns / new_tokens + for _ in range(new_tokens): + self.metrics.decode_times_ns.append(time_per_token) + + # Record the current token batch timestamp + self.metrics.token_timestamps_ns.append(current_time_ns) + self.metrics.total_tokens_generated += new_tokens + self.metrics.prev_tokens_count = current_token_count + + # Calculate current metrics + self._update_metrics(current_time_ns) + + return self.get_metrics_dict() + + def _update_metrics(self, current_time_ns: Optional[int] = None) -> None: + """Update the performance metrics based on current state""" + if current_time_ns is None: + current_time_ns = time.time_ns() + + # Record total time + self.metrics.total_time_ns = current_time_ns - self.metrics.start_time_ns + + # Calculate average decode speed + if self.metrics.decode_times_ns: + # Convert to seconds + decode_times_sec = [t / 1e9 for t in self.metrics.decode_times_ns] + self.metrics.avg_decode_time = sum(decode_times_sec) / len(decode_times_sec) + self.metrics.decode_tokens_per_second = 1.0 / self.metrics.avg_decode_time + + # Calculate end-to-end throughput + total_time_sec = self.metrics.total_time_ns / 1e9 + if total_time_sec > 0: + total_tokens = ( + self.metrics.input_token_count + self.metrics.total_tokens_generated + ) + self.metrics.end_to_end_tokens_per_second = total_tokens / total_time_sec + + def end_generation(self) -> Dict[str, any]: + """Mark the end of generation and finalize metrics""" + current_time_ns = time.time_ns() + self._update_metrics(current_time_ns) + 
+ # Log final performance data + total_time_sec = self.metrics.total_time_ns / 1e9 + logger.info(f"Generation complete. Total time: {total_time_sec:.6f}s") + + if self.metrics.prefill_tokens_per_second: + logger.info( + f"Final prefill speed: {self.metrics.prefill_tokens_per_second:.2f} " + "tokens/s" + ) + + if self.metrics.decode_tokens_per_second: + logger.info( + f"Final decode speed: {self.metrics.decode_tokens_per_second:.2f} " + "tokens/s" + ) + + if self.metrics.end_to_end_tokens_per_second: + logger.info( + "End-to-end throughput: " + f"{self.metrics.end_to_end_tokens_per_second:.2f} tokens/s" + ) + + return self.get_metrics_dict() + + def get_metrics_dict(self) -> Dict[str, any]: + """Get performance metrics as dictionary, converting nanoseconds to seconds + for external use""" + return self.metrics.to_dict()
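
Usage sketch (editorial addition, not part of the patch): the snippet below shows one way the new LLMPerformanceMonitor from dbgpt/model/utils/llm_metrics.py is meant to be driven, mirroring how vllm_llm.py and PerformanceMonitoringStreamer call it. The prompt size, the simulated decode loop, and the sleep timings are made-up illustration values; the printed keys are the ones exposed by the patch's get_metrics_dict().

import time

from dbgpt.model.utils.llm_metrics import LLMPerformanceMonitor

# Assume the prompt was tokenized into 12 tokens (illustrative value).
monitor = LLMPerformanceMonitor(input_token_count=12)

# Mark the start of the prefill phase right before generation begins.
monitor.start_prefill()

# Simulate a streaming decode loop: report the *cumulative* completion token
# count after each step, as the vLLM stream handler and the HF streamer do.
for produced in range(1, 6):
    time.sleep(0.05)  # stand-in for real decode latency
    monitor.on_tokens_received(produced)

# Finalize the run and read back the aggregated metrics.
final_metrics = monitor.end_generation()
print(final_metrics.get("prefill_tokens_per_second"))
print(final_metrics.get("decode_tokens_per_second"))
print(final_metrics.get("end_to_end_tokens_per_second"))

The monitor infers the prefill/decode boundary itself from the first non-empty token batch, so callers only need to report cumulative token counts as they arrive.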