mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[example] update Grok-1 inference (#5495)
* revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url
This commit is contained in:
@@ -2,8 +2,7 @@ import time
|
||||
|
||||
import torch
|
||||
from grok1_policy import Grok1ForCausalLMPolicy
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizerFast
|
||||
from utils import get_defualt_parser, inference, print_output
|
||||
|
||||
import colossalai
|
||||
@@ -33,11 +32,17 @@ if __name__ == "__main__":
|
||||
args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
|
||||
)
|
||||
model, *_ = booster.boost(model)
|
||||
sp = SentencePieceProcessor(model_file=args.tokenizer)
|
||||
model.eval()
|
||||
init_time = time.time() - start
|
||||
|
||||
# A transformers-compatible version of the grok-1 tokenizer by Xenova
|
||||
# https://huggingface.co/Xenova/grok-1-tokenizer
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
|
||||
|
||||
for text in args.text:
|
||||
output = inference(
|
||||
model.unwrap(),
|
||||
sp,
|
||||
tokenizer,
|
||||
text,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
do_sample=args.do_sample,
|
||||
@@ -46,5 +51,14 @@ if __name__ == "__main__":
|
||||
top_p=args.top_p,
|
||||
)
|
||||
if coordinator.is_master():
|
||||
print_output(text, sp.decode(output))
|
||||
coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")
|
||||
print_output(text, tokenizer.decode(output))
|
||||
|
||||
overall_time = time.time() - start
|
||||
gen_latency = overall_time - init_time
|
||||
avg_gen_latency = gen_latency / len(args.text)
|
||||
coordinator.print_on_master(
|
||||
f"Initializing time: {init_time:.2f} seconds.\n"
|
||||
f"Overall time: {overall_time:.2f} seconds. \n"
|
||||
f"Generation latency: {gen_latency:.2f} seconds. \n"
|
||||
f"Average generation latency: {avg_gen_latency:.2f} seconds. \n"
|
||||
)
|
||||
|
Reference in New Issue
Block a user