Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-31 05:49:56 +00:00)
Latest commit:

* refactor latest code
* update api
* add dummy dataset
* update Readme
* add setup
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update files
* add PP support
* update arguments
* update argument
* reorg folder
* update version
* remove IB info
* update utils
* update readme
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update save for zero
* update save
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* add apex
* update

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

77 lines · 3.0 KiB · Python
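The file below is the Colossal-LLaMA-2 inference example script. It loads a causal language model and its tokenizer via Hugging Face transformers, builds a prompt in either the SFT conversation format or the raw pretraining format, runs generation on the configured device, and prints the response through the distributed logger.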
		
	
	
	
	
	
import argparse

import torch
from colossal_llama.dataset.conversation import default_conversation
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.logging import get_dist_logger

logger = get_dist_logger()


def load_model(model_path, device="cuda", **kwargs):
    """Load the model and tokenizer from `model_path`; extra kwargs are forwarded to `from_pretrained`."""
    logger.info("Please check whether the tokenizer and model weights are properly stored in the same folder.")
    model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
    model.to(device)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left")
    except OSError as e:
        raise FileNotFoundError(
            "Tokenizer not found. Please check if the tokenizer exists or the model path is correct."
        ) from e

    return model, tokenizer
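
# Note: load_model forwards **kwargs to AutoModelForCausalLM.from_pretrained, so loading
# options can be passed at the call site. A sketch (torch_dtype is a standard
# from_pretrained kwarg; the call itself is illustrative, not part of the original flow):
#
#     model, tokenizer = load_model(
#         "hpcai-tech/Colossal-LLaMA-2-7b-base", device="cuda:0", torch_dtype=torch.float16
#     )
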
@torch.inference_mode()
def generate(args):
    model, tokenizer = load_model(model_path=args.model_path, device=args.device)

    if args.prompt_style == "sft":
        # Chat-style prompt: wrap the input in the SFT conversation template.
        conversation = default_conversation.copy()
        conversation.append_message("Human", args.input_txt)
        conversation.append_message("Assistant", None)
        input_txt = conversation.get_prompt()
    else:
        # Pretrained-style prompt: append a plain completion suffix to the raw input.
        BASE_INFERENCE_SUFFIX = "\n\n->\n\n"
        input_txt = f"{args.input_txt}{BASE_INFERENCE_SUFFIX}"
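    # Illustration (derived from the branch above): with --prompt_style pretrained and
    # the default --input_txt, the prompt is the literal string "明月松间照,\n\n->\n\n".
    # The sft branch instead renders colossal_llama's default_conversation template,
    # roughly a "Human: ... Assistant:" transcript (exact text depends on that template).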
    inputs = tokenizer(input_txt, return_tensors="pt").to(args.device)
    num_input_tokens = inputs["input_ids"].shape[-1]
    output = model.generate(
        **inputs,
        max_new_tokens=args.max_new_tokens,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        num_return_sequences=1,
    )
    # Decode only the newly generated tokens, skipping the prompt tokens.
    response = tokenizer.decode(output.cpu()[0, num_input_tokens:], skip_special_tokens=True)
    logger.info(f"\nHuman: {args.input_txt} \n\nAssistant: \n{response}")
    return response
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Colossal-LLaMA-2 inference process.")
    parser.add_argument(
        "--model_path",
        type=str,
        default="hpcai-tech/Colossal-LLaMA-2-7b-base",
        help="HF repo name or local path of the model",
    )
    parser.add_argument("--device", type=str, default="cuda:0", help="Set the device")
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate, ignoring the number of tokens in the prompt",
    )
    # `type=bool` would treat any non-empty string (including "False") as True;
    # BooleanOptionalAction (Python 3.9+) parses --do_sample / --no-do_sample correctly.
    parser.add_argument(
        "--do_sample", action=argparse.BooleanOptionalAction, default=True, help="Whether or not to use sampling"
    )
    parser.add_argument("--temperature", type=float, default=0.3, help="Set temperature value")
    parser.add_argument("--top_k", type=int, default=50, help="Set top_k value for top-k-filtering")
    parser.add_argument("--top_p", type=float, default=0.95, help="Set top_p value for generation")
    parser.add_argument("--input_txt", type=str, default="明月松间照,", help="The prompt input to the model")
    parser.add_argument("--prompt_style", choices=["sft", "pretrained"], default="sft", help="The style of the prompt")
    args = parser.parse_args()
    generate(args)
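A minimal usage sketch (the values shown are just the script's own argparse defaults, and the programmatic Namespace call is illustrative; running the script from the command line with no flags is equivalent):

from argparse import Namespace

args = Namespace(
    model_path="hpcai-tech/Colossal-LLaMA-2-7b-base",
    device="cuda:0",
    max_new_tokens=512,
    do_sample=True,
    temperature=0.3,
    top_k=50,
    top_p=0.95,
    input_txt="明月松间照,",
    prompt_style="sft",
)
response = generate(args)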