Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-11-03 23:48:41 +00:00)
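The file below is a self-contained inference example: it loads a tokenizer and a causal language model through Hugging Face transformers (in bfloat16, with trust_remote_code enabled and device_map="auto" so the weights are sharded across the available devices), generates a completion for each prompt passed on the command line, and reports load time and generation latency.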
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_default_parser, inference, print_output

if __name__ == "__main__":
    parser = get_default_parser()
    args = parser.parse_args()
    start = time.time()
    torch.set_default_dtype(torch.bfloat16)

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)

    # Load the model in bfloat16 and let device_map="auto" shard the weights
    # across whatever GPUs (or CPU) are available.
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    init_time = time.time() - start

    # Generate a completion for each input prompt.
    for text in args.text:
        output = inference(
            model,
            tokenizer,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, tokenizer.decode(output))

    # Report timing, separating model loading from generation.
    overall_time = time.time() - start
    gen_latency = overall_time - init_time
    avg_gen_latency = gen_latency / len(args.text)
    print(
        f"Initialization time: {init_time:.2f} seconds.\n"
        f"Overall time: {overall_time:.2f} seconds.\n"
        f"Generation latency: {gen_latency:.2f} seconds.\n"
        f"Average generation latency: {avg_gen_latency:.2f} seconds.\n"
    )
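The script imports three helpers from a sibling utils module that is not shown here. For orientation, a minimal sketch of what they might look like, inferred only from how the script uses them (the attribute names read off args, and the fact that inference returns token ids that are passed to tokenizer.decode). All defaults and flag names are assumptions; the actual utils.py in the ColossalAI repository may differ.

import argparse

import torch


def get_default_parser() -> argparse.ArgumentParser:
    # Hypothetical flags, matching the attributes the example script reads.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained", type=str, required=True)
    parser.add_argument("--text", type=str, nargs="+", default=["Hello, my name is"])
    parser.add_argument("--max_new_tokens", type=int, default=100)
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    return parser


@torch.no_grad()
def inference(model, tokenizer, text, **generate_kwargs):
    # Tokenize one prompt, move it to the model's device, and return the
    # generated token ids (prompt included) for the caller to decode.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    return model.generate(input_ids, **generate_kwargs)[0]


def print_output(text, output):
    # Print the prompt alongside the decoded completion.
    print(f"-----\nPrompt: {text}\nOutput: {output}\n-----")

Under these assumptions, a typical invocation would be python inference.py --pretrained <model path or HF hub id> --text "your prompt" --max_new_tokens 64; the real flag names should be checked against the repository's utils.py.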