Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-24 17:33:39 +00:00)
[example] Update Inference Example (#5725)
* [example] update inference example
@@ -27,7 +27,7 @@ def infer(args):
     model = MODEL_CLS.from_pretrained(model_path_or_name)
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
-    coordinator.print_on_master(f"Model Config:\n{model.config}")
+    # coordinator.print_on_master(f"Model Config:\n{model.config}")
 
     # ==============================
     # Initialize InferenceEngine
@@ -52,20 +52,39 @@ def infer(args):
         pad_token_id=tokenizer.eos_token_id,
         eos_token_id=tokenizer.eos_token_id,
         max_length=args.max_length,
-        do_sample=True,
+        do_sample=args.do_sample,
+        temperature=args.temperature,
+        top_k=args.top_k,
+        top_p=args.top_p,
     )
     coordinator.print_on_master(f"Generating...")
     out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
-    coordinator.print_on_master(out[0])
+    coordinator.print_on_master(out)
+
+    # ==============================
+    # Optionally, load drafter model and proceed speculative decoding
+    # ==============================
+    drafter_model_path_or_name = args.drafter_model
+    if drafter_model_path_or_name is not None:
+        drafter_model = AutoModelForCausalLM.from_pretrained(drafter_model_path_or_name)
+        # turn on speculative decoding with the drafter model
+        engine.enable_spec_dec(drafter_model)
+        coordinator.print_on_master(f"Generating...")
+        out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
+        coordinator.print_on_master(out)
+
+        engine.disable_spec_dec()
+
 
 # colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
 # colossalai run --nproc_per_node 2 llama_generation.py -m MODEL_PATH --tp_size 2
 if __name__ == "__main__":
     # ==============================
     # Parse Arguments
     # ==============================
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
+    parser.add_argument("--drafter_model", type=str, help="Path to the drafter model or model name")
     parser.add_argument(
         "-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
     )
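For context, a minimal sketch of how the sampling arguments introduced above could be assembled into the generation config passed to engine.generate. It assumes GenerationConfig is transformers.GenerationConfig and that args and tokenizer are the objects already present in infer(); treat it as illustrative rather than the example's actual constructor call:

from transformers import GenerationConfig

def build_generation_config(args, tokenizer):
    # Greedy decoding unless --do_sample is passed; temperature, top_k and top_p
    # only take effect when do_sample=True.
    return GenerationConfig(
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_length=args.max_length,
        do_sample=args.do_sample,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )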
@@ -75,7 +94,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
|
||||
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
|
||||
parser.add_argument("--max_length", type=int, default=32, help="Max length for generation")
|
||||
# Generation configs
|
||||
parser.add_argument("--max_length", type=int, default=64, help="Max length for generation")
|
||||
parser.add_argument("--do_sample", action="store_true", help="Use sampling for generation")
|
||||
parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for generation")
|
||||
parser.add_argument("--top_k", type=int, default=50, help="Top k for generation")
|
||||
parser.add_argument("--top_p", type=float, default=1.0, help="Top p for generation")
|
||||
args = parser.parse_args()
|
||||
|
||||
infer(args)
|
||||
|
||||
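Putting the new flags together, a hypothetical launch command in the same style as the run commands already in the example; MODEL_PATH and DRAFTER_PATH are placeholders, and the sampling values are arbitrary illustrative choices:

# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH --drafter_model DRAFTER_PATH --do_sample --temperature 0.7 --top_k 50 --top_p 0.9 --max_length 64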