mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[example] add grok-1 inference (#5485)
* [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [example] add test ci * [example] update readme
This commit is contained in:
32
examples/language/grok-1/inference.py
Normal file
32
examples/language/grok-1/inference.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import time

import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM

from utils import get_defualt_parser, inference, print_output

if __name__ == "__main__":
    # Script entry point: load a pretrained Grok-1 checkpoint in bfloat16 and
    # generate a completion for every prompt supplied on the command line.
    # NOTE: "defualt" is a typo, but it mirrors the name exported by utils;
    # renaming it here would break the import.
    parser = get_defualt_parser()
    args = parser.parse_args()
    # perf_counter() is monotonic; time.time() can jump backwards/forwards if
    # the wall clock is adjusted, which would corrupt the elapsed-time report.
    start = time.perf_counter()
    # Default new tensors to bfloat16 so anything created outside
    # from_pretrained matches the model's dtype.
    torch.set_default_dtype(torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,  # checkpoint ships its own modeling code
        device_map="auto",  # let HF/accelerate place weights on available devices
        torch_dtype=torch.bfloat16,
    )
    # Inference-only script: make sure dropout/batch-norm are in eval mode.
    # (Idempotent even if the inference helper also calls eval().)
    model.eval()
    # Grok-1 uses a raw SentencePiece model file rather than an HF tokenizer.
    sp = SentencePieceProcessor(model_file=args.tokenizer)
    for text in args.text:
        # `inference` and `print_output` come from the sibling utils module;
        # sampling behavior is controlled entirely by the CLI flags below.
        output = inference(
            model,
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, sp.decode(output))
    print(f"Overall time: {time.perf_counter() - start} seconds.")
|
Reference in New Issue
Block a user