ColossalAI/colossalai/legacy/inference/dynamic_batching/stats.py
Xu Kai fd6482ad8c
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)

* support only tp

* enable tp

* add support for bloom (#5008)

* [refactor] refactor gptq and smoothquant llama (#5012)

* refactor gptq and smoothquant llama

* fix import error

* fix linear import torch-int

* fix smoothquant llama import error

* fix import accelerate error

* fix bug

* fix import smooth cuda

* fix smoothcuda

* [Inference Refactor] Merge chatglm2 with pp and tp (#5023)

merge chatglm with pp and tp

* [Refactor] remove useless inference code (#5022)

* remove useless code

* fix quant model

* fix test import bug

* mv original inference legacy

* fix chatglm2

* [Refactor] refactor policy search and quant type controlling in inference (#5035)

* [Refactor] refactor policy search and quant type controlling in inference

* [inference] update readme (#5051)

* update readme

* update readme

* fix architecture

* fix table

* fix table

* [inference] update example (#5053)

* update example

* fix run.sh

* fix rebase bug

* fix some errors

* update readme

* add some features

* update interface

* update readme

* update benchmark

* add requirements-infer

---------

Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
2023-11-19 21:05:05 +08:00

46 lines
1.5 KiB
Python

# Adapted from https://github.com/ModelTC/lightllm
import time


class Stats:
    """Accumulates token counts and periodically prints throughput."""

    def __init__(self, log_status, log_stats_interval) -> None:
        # log_status: whether to collect and print statistics at all.
        # log_stats_interval: minimum number of seconds between two reports.
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0

    def count_prompt_tokens(self, run_batch):
        # Add the batch's input (prompt) tokens to the running totals.
        if self.log_stats:
            tokens = run_batch.input_tokens()
            self.prompt_tokens += tokens
            self.all_tokens += tokens

    def count_output_tokens(self, run_batch):
        # Count one generated token for each request in the batch.
        if self.log_stats:
            tokens = len(run_batch.reqs)
            self.output_tokens += tokens
            self.all_tokens += tokens

    def print_stats(self):
        if not self.log_stats:
            return

        now = time.time()
        elapsed = now - self.last_log_time
        if elapsed > self.log_stats_interval:
            print(
                f"Avg tokens(prompt+generate) throughput: {self.all_tokens / elapsed:8.3f} tokens/s\n"
                f"Avg prompt tokens throughput: {self.prompt_tokens / elapsed:8.3f} tokens/s\n"
                f"Avg generate tokens throughput: {self.output_tokens / elapsed:8.3f} tokens/s"
            )
            # Start a new measurement window.
            self.all_tokens = 0
            self.output_tokens = 0
            self.prompt_tokens = 0
            self.last_log_time = now
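
A minimal usage sketch of how the counters accumulate across one prefill and two decode steps. `FakeBatch` is a hypothetical stand-in, not part of ColossalAI or lightllm; it only provides the `input_tokens()` method and the `reqs` attribute that `Stats` reads. The counting part of the class is repeated in condensed form (without `print_stats`) so the snippet runs on its own.

```python
import time


class Stats:
    """Condensed copy of the counting logic above, so the sketch is standalone."""

    def __init__(self, log_status, log_stats_interval):
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0

    def count_prompt_tokens(self, run_batch):
        if self.log_stats:
            tokens = run_batch.input_tokens()
            self.prompt_tokens += tokens
            self.all_tokens += tokens

    def count_output_tokens(self, run_batch):
        if self.log_stats:
            tokens = len(run_batch.reqs)
            self.output_tokens += tokens
            self.all_tokens += tokens


class FakeBatch:
    """Hypothetical batch object (an assumption made for illustration only)."""

    def __init__(self, prompt_lengths):
        # One entry per request; each value is that request's prompt length.
        self.reqs = list(prompt_lengths)

    def input_tokens(self):
        # Total number of prompt tokens across all requests in the batch.
        return sum(self.reqs)


stats = Stats(log_status=True, log_stats_interval=10)
batch = FakeBatch([5, 7, 3])

stats.count_prompt_tokens(batch)  # prefill: 5 + 7 + 3 prompt tokens
stats.count_output_tokens(batch)  # decode step: one token per request
stats.count_output_tokens(batch)  # second decode step

counters = (stats.prompt_tokens, stats.output_tokens, stats.all_tokens)
print(counters)

# With logging disabled, the counting methods leave the totals untouched.
disabled = Stats(log_status=False, log_stats_interval=10)
disabled.count_prompt_tokens(batch)
disabled.count_output_tokens(batch)
```

Because `count_output_tokens` adds `len(run_batch.reqs)`, the generated-token total grows by the batch size on every call, which matches a scheduler that calls it once per decode step.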