ColossalAI/examples/language/grok-1/inference_tp.py

import time

import torch
from grok1_policy import Grok1ForCausalLMPolicy
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils import get_default_parser, inference, print_output

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.utils import get_current_device
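
# A typical way to launch this script (a sketch based on the usual ColossalAI
# example workflow; the GPU count and checkpoint path below are placeholders,
# adjust them to your setup):
#
#   torchrun --standalone --nproc_per_node 8 inference_tp.py \
#       --pretrained /path/to/grok-1 --text "The future of AI is"
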
if __name__ == "__main__":
    parser = get_default_parser()
    args = parser.parse_args()
    start = time.time()

    # Initialize the distributed environment from the env vars set by torchrun.
    colossalai.launch_from_torch()
    coordinator = DistCoordinator()

    # Shard the model with tensor parallelism across all ranks; pipeline
    # parallelism is disabled (pp_size=1) and weights are kept in bf16.
    plugin = HybridParallelPlugin(
        tp_size=coordinator.world_size,
        pp_size=1,
        precision="bf16",
        parallel_output=False,
        custom_policy=Grok1ForCausalLMPolicy(),
    )
    booster = Booster(plugin=plugin)
    torch.set_default_dtype(torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)

    # Lazy initialization defers weight materialization so the full model is
    # not allocated on every rank before the booster shards it.
    with LazyInitContext(default_device=get_current_device()):
        model = AutoModelForCausalLM.from_pretrained(
            args.pretrained, trust_remote_code=True, torch_dtype=torch.bfloat16
        )
    model, *_ = booster.boost(model)
    model.eval()
    init_time = time.time() - start

    # Generate for each prompt; only the master rank prints, to avoid
    # duplicated output across tensor-parallel workers.
    for text in args.text:
        output = inference(
            model.unwrap(),
            tokenizer,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        if coordinator.is_master():
            print_output(text, tokenizer.decode(output))

    # Report initialization time separately from per-prompt generation latency.
    overall_time = time.time() - start
    gen_latency = overall_time - init_time
    avg_gen_latency = gen_latency / len(args.text)
    coordinator.print_on_master(
        f"Initializing time: {init_time:.2f} seconds.\n"
        f"Overall time: {overall_time:.2f} seconds.\n"
        f"Generation latency: {gen_latency:.2f} seconds.\n"
        f"Average generation latency: {avg_gen_latency:.2f} seconds.\n"
    )