From 677cbfacf8ef11f423ec1f5216083675615ab85d Mon Sep 17 00:00:00 2001
From: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Date: Thu, 30 May 2024 13:48:46 +0800
Subject: [PATCH] [Fix/Example] Fix Llama Inference Loading Data Type (#5763)

* [fix/example] fix llama inference loading dtype

* revise loading dtype of benchmark_llama3

Without an explicit torch_dtype, from_pretrained loads the checkpoint in
float32 regardless of the dtype requested on the command line. Map the dtype
string ("fp16", "fp32", "bf16") to the corresponding torch dtype and pass it
through when loading the model in both example scripts.
---
 examples/inference/llama/benchmark_llama3.py | 12 +++++++++++-
 examples/inference/llama/llama_generation.py |  9 ++++++++-
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/examples/inference/llama/benchmark_llama3.py b/examples/inference/llama/benchmark_llama3.py
index 07ebdb2b1..76d9c6a42 100644
--- a/examples/inference/llama/benchmark_llama3.py
+++ b/examples/inference/llama/benchmark_llama3.py
@@ -17,6 +17,13 @@ GIGABYTE = 1024**3
 MEGABYTE = 1024**2
 N_WARMUP_STEPS = 2
 
+TORCH_DTYPE_MAP = {
+    "fp16": torch.float16,
+    "fp32": torch.float32,
+    "bf16": torch.bfloat16,
+}
+
+
 CONFIG_MAP = {
     "toy": transformers.LlamaConfig(num_hidden_layers=4),
     "llama-7b": transformers.LlamaConfig(
@@ -104,10 +111,13 @@ def print_details_info(model_config, whole_end2end, total_token_num, dtype, coor
 def benchmark_inference(args):
     coordinator = DistCoordinator()
 
+    torch_dtype = TORCH_DTYPE_MAP.get(args.dtype, None)
     config = CONFIG_MAP[args.model]
+    config.torch_dtype = torch_dtype
     config.pad_token_id = config.eos_token_id
+
     if args.model_path is not None:
-        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
+        model = transformers.LlamaForCausalLM.from_pretrained(args.model_path, torch_dtype=torch_dtype)
         tokenizer = AutoTokenizer.from_pretrained(args.model_path)
     else:
         # Random weights
diff --git a/examples/inference/llama/llama_generation.py b/examples/inference/llama/llama_generation.py
index c0a1a585a..9326f717c 100644
--- a/examples/inference/llama/llama_generation.py
+++ b/examples/inference/llama/llama_generation.py
@@ -1,5 +1,6 @@
 import argparse
 
+from torch import bfloat16, float16, float32
 from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 import colossalai
@@ -12,6 +13,12 @@ from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaM
 MODEL_CLS = AutoModelForCausalLM
 POLICY_CLS = NoPaddingLlamaModelInferPolicy
 
+TORCH_DTYPE_MAP = {
+    "fp16": float16,
+    "fp32": float32,
+    "bf16": bfloat16,
+}
+
 
 def infer(args):
     # ==============================
@@ -24,7 +31,7 @@ def infer(args):
     # Load model and tokenizer
     # ==============================
     model_path_or_name = args.model
-    model = MODEL_CLS.from_pretrained(model_path_or_name)
+    model = MODEL_CLS.from_pretrained(model_path_or_name, torch_dtype=TORCH_DTYPE_MAP.get(args.dtype, None))
     tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
     tokenizer.pad_token = tokenizer.eos_token
     # coordinator.print_on_master(f"Model Config:\n{model.config}")
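
For reference, a minimal standalone sketch of the dtype-aware loading pattern
this patch applies in both scripts. The helper name load_model, the default
dtype, and the commented-out model id are illustrative placeholders, not part
of the patch.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Same mapping added in benchmark_llama3.py and llama_generation.py.
    TORCH_DTYPE_MAP = {
        "fp16": torch.float16,
        "fp32": torch.float32,
        "bf16": torch.bfloat16,
    }

    def load_model(model_path_or_name: str, dtype: str = "fp16"):
        # Passing torch_dtype makes from_pretrained materialize the weights in
        # the requested precision instead of the float32 default; an unknown
        # dtype string falls back to None, i.e. the library default.
        torch_dtype = TORCH_DTYPE_MAP.get(dtype, None)
        model = AutoModelForCausalLM.from_pretrained(model_path_or_name, torch_dtype=torch_dtype)
        tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
        tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    # Example usage (hypothetical model id):
    # model, tokenizer = load_model("meta-llama/Meta-Llama-3-8B", dtype="bf16")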