mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-18 19:58:17 +00:00
* [inference] support only TP (#4998) * support only tp * enable tp * add support for bloom (#5008) * [refactor] refactor gptq and smoothquant llama (#5012) * refactor gptq and smoothquant llama * fix import error * fix linear import torch-int * fix smoothquant llama import error * fix import accelerate error * fix bug * fix import smooth cuda * fix smoothcuda * [Inference Refactor] Merge chatglm2 with pp and tp (#5023) merge chatglm with pp and tp * [Refactor] remove useless inference code (#5022) * remove useless code * fix quant model * fix test import bug * mv original inference legacy * fix chatglm2 * [Refactor] refactor policy search and quant type controlling in inference (#5035) * [Refactor] refactor policy search and quant type controling in inference * [inference] update readme (#5051) * update readme * update readme * fix architecture * fix table * fix table * [inference] udpate example (#5053) * udpate example * fix run.sh * fix rebase bug * fix some errors * update readme * add some features * update interface * update readme * update benchmark * add requirements-infer --------- Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
46 lines
1.5 KiB
Python
46 lines
1.5 KiB
Python
# Adapted from https://github.com/ModelTC/lightllm
|
|
|
|
import time
|
|
|
|
|
|
class Stats:
    """Token counters with periodic throughput reporting on stdout.

    The instance keeps three running totals (prompt tokens, output tokens and
    their sum) and, when asked via :meth:`print_stats`, prints the average
    tokens-per-second since the previous report and starts a new window.

    Attributes:
        log_stats: Truthy to enable counting and printing; when falsy every
            method is a no-op.
        log_stats_interval: Elapsed time (in ``time.time()`` units, i.e.
            seconds) that must be exceeded before a report is printed.
        last_log_time: ``time.time()`` value at construction or at the most
            recent report.
        all_tokens: Prompt plus output tokens counted in the current window.
        output_tokens: Output tokens counted in the current window.
        prompt_tokens: Prompt tokens counted in the current window.
    """

    def __init__(self, log_status, log_stats_interval) -> None:
        """Store the configuration and start the first measurement window.

        Args:
            log_status: Enables statistics when truthy. Stored on the
                instance under the name ``log_stats``.
            log_stats_interval: Threshold compared against the elapsed time
                in :meth:`print_stats`.
        """
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = self.output_tokens = self.prompt_tokens = 0

    def count_prompt_tokens(self, run_batch):
        """Add ``run_batch.input_tokens()`` to the prompt and overall totals.

        ``run_batch`` is not touched at all when statistics are disabled.
        """
        if not self.log_stats:
            return
        n_prompt = run_batch.input_tokens()
        self.prompt_tokens += n_prompt
        self.all_tokens += n_prompt

    def count_output_tokens(self, run_batch):
        """Add ``len(run_batch.reqs)`` to the output and overall totals.

        ``run_batch`` is not touched at all when statistics are disabled.
        """
        if not self.log_stats:
            return
        # NOTE(review): presumably each request in the batch yields exactly
        # one new token per call -- confirm against the caller.
        n_generated = len(run_batch.reqs)
        self.output_tokens += n_generated
        self.all_tokens += n_generated

    def print_stats(self):
        """Print average throughput and reset the window when it is due.

        Nothing happens when statistics are disabled or when the time since
        ``last_log_time`` does not exceed ``log_stats_interval``. Otherwise
        three lines (overall, prompt, generate) are printed, the counters are
        zeroed and ``last_log_time`` is moved to the current time.
        """
        if not self.log_stats:
            return

        now = time.time()
        elapsed = now - self.last_log_time
        if elapsed > self.log_stats_interval:
            report_lines = (
                f"Avg tokens(prompt+generate) throughput: {self.all_tokens / elapsed:8.3f} tokens/s",
                f"Avg prompt tokens throughput: {self.prompt_tokens / elapsed:8.3f} tokens/s",
                f"Avg generate tokens throughput: {self.output_tokens / elapsed:8.3f} tokens/s",
            )
            print("\n".join(report_lines))

            # Start a fresh measurement window.
            self.all_tokens = self.output_tokens = self.prompt_tokens = 0
            self.last_log_time = now
|