[example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme
This commit is contained in:
Hongxin Liu
2024-03-21 18:07:22 +08:00
committed by GitHub
parent d158fc0e64
commit 848a574c26
9 changed files with 297 additions and 0 deletions

View File

@@ -0,0 +1,32 @@
import time
import torch
from sentencepiece import SentencePieceProcessor
from transformers import AutoModelForCausalLM
from utils import get_defualt_parser, inference, print_output
if __name__ == "__main__":
    # Grok-1 inference example: load the checkpoint, run generation for each
    # prompt passed on the command line, and report the total wall-clock time.
    args = get_defualt_parser().parse_args()
    start = time.time()

    # Load the model in bfloat16 and let HF accelerate shard it across the
    # available devices (device_map="auto").
    torch.set_default_dtype(torch.bfloat16)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    tokenizer = SentencePieceProcessor(model_file=args.tokenizer)

    # Sampling options are identical for every prompt, so build them once.
    generation_kwargs = dict(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )
    for prompt in args.text:
        generated = inference(model, tokenizer, prompt, **generation_kwargs)
        print_output(prompt, tokenizer.decode(generated))

    print(f"Overall time: {time.time() - start} seconds.")