[example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url
Author: Yuanheng Zhao
Date: 2024-03-24 20:24:11 +08:00
Committed by: GitHub
Parent: 6df844b8c4
Commit: 5fcd7795cd
7 changed files with 69 additions and 43 deletions


@@ -2,8 +2,7 @@ import time
 
 import torch
 from grok1_policy import Grok1ForCausalLMPolicy
-from sentencepiece import SentencePieceProcessor
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, LlamaTokenizerFast
 from utils import get_defualt_parser, inference, print_output
 
 import colossalai
@@ -33,11 +32,17 @@ if __name__ == "__main__":
             args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
         )
     model, *_ = booster.boost(model)
-    sp = SentencePieceProcessor(model_file=args.tokenizer)
+    model.eval()
+    init_time = time.time() - start
+
+    # A transformers-compatible version of the grok-1 tokenizer by Xenova
+    # https://huggingface.co/Xenova/grok-1-tokenizer
+    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
+
     for text in args.text:
         output = inference(
             model.unwrap(),
-            sp,
+            tokenizer,
             text,
             max_new_tokens=args.max_new_tokens,
             do_sample=args.do_sample,
@@ -46,5 +51,14 @@ if __name__ == "__main__":
             top_p=args.top_p,
         )
         if coordinator.is_master():
-            print_output(text, sp.decode(output))
-    coordinator.print_on_master(f"Overall time: {time.time() - start} seconds.")
+            print_output(text, tokenizer.decode(output))
+
+    overall_time = time.time() - start
+    gen_latency = overall_time - init_time
+    avg_gen_latency = gen_latency / len(args.text)
+    coordinator.print_on_master(
+        f"Initializing time: {init_time:.2f} seconds.\n"
+        f"Overall time: {overall_time:.2f} seconds. \n"
+        f"Generation latency: {gen_latency:.2f} seconds. \n"
+        f"Average generation latency: {avg_gen_latency:.2f} seconds. \n"
+    )
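
For reference, a minimal standalone sketch of the new tokenizer path, without the distributed model setup. It assumes only that transformers (with its tokenizers backend) is installed and the Hugging Face Hub is reachable; the prompt string is made up for illustration:

    # Sketch: load the transformers-compatible Grok-1 tokenizer used above
    # and round-trip a prompt, timing initialization separately as the
    # commit does for the full pipeline.
    import time

    from transformers import LlamaTokenizerFast

    start = time.time()
    tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer")
    init_time = time.time() - start  # analogous to the init/generation split above

    prompt = "The moon is a"
    ids = tokenizer(prompt).input_ids  # list of token ids
    print(f"Tokenizer loaded in {init_time:.2f}s; prompt -> {len(ids)} tokens")
    print(tokenizer.decode(ids, skip_special_tokens=True))

Loading the tokenizer from the Hub avoids passing a local SentencePiece tokenizer.model file, which is presumably why the "remove unused arg in scripts" change could drop the tokenizer argument.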