Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 09:38:05 +00:00
[inference] Refactor inference architecture (#5057)
* [inference] support only TP (#4998)
* support only tp
* enable tp
* add support for bloom (#5008)
* [refactor] refactor gptq and smoothquant llama (#5012)
* refactor gptq and smoothquant llama
* fix import error
* fix linear import torch-int
* fix smoothquant llama import error
* fix import accelerate error
* fix bug
* fix import smooth cuda
* fix smoothcuda
* [Inference Refactor] Merge chatglm2 with pp and tp (#5023): merge chatglm with pp and tp
* [Refactor] remove useless inference code (#5022)
* remove useless code
* fix quant model
* fix test import bug
* mv original inference legacy
* fix chatglm2
* [Refactor] refactor policy search and quant type controlling in inference (#5035)
* [Refactor] refactor policy search and quant type controlling in inference
* [inference] update readme (#5051)
* update readme
* update readme
* fix architecture
* fix table
* fix table
* [inference] update example (#5053)
* update example
* fix run.sh
* fix rebase bug
* fix some errors
* update readme
* add some features
* update interface
* update readme
* update benchmark
* add requirements-infer
---------
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
colossalai/legacy/inference/dynamic_batching/stats.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Adapted from https://github.com/ModelTC/lightllm

import time


class Stats:
    def __init__(self, log_status, log_stats_interval) -> None:
        self.log_stats = log_status
        self.log_stats_interval = log_stats_interval
        self.last_log_time = time.time()
        self.all_tokens = 0
        self.output_tokens = 0
        self.prompt_tokens = 0
        return

    def count_prompt_tokens(self, run_batch):
        if self.log_stats:
            tokens = run_batch.input_tokens()
            self.prompt_tokens += tokens
            self.all_tokens += tokens
        return

    def count_output_tokens(self, run_batch):
        if self.log_stats:
            tokens = len(run_batch.reqs)
            self.output_tokens += tokens
            self.all_tokens += tokens
        return

    def print_stats(self):
        if not self.log_stats:
            return

        now = time.time()
        if now - self.last_log_time > self.log_stats_interval:
            print(
                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
            )
            self.all_tokens = 0
            self.output_tokens = 0
            self.prompt_tokens = 0
            self.last_log_time = now
        return