mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 13:00:52 +00:00
[Fix] Fix Inference Example, Tests, and Requirements (#5688)
* clean requirements * modify example inference struct * add test ci scripts * mark test_infer as submodule * rm deprecated cls & deps * import of HAS_FLASH_ATTN * prune inference tests to be run * prune triton kernel tests * increment pytest timeout mins * revert import path in openmoe
This commit is contained in:
275
examples/inference/llama/benchmark_llama.py
Normal file
275
examples/inference/llama/benchmark_llama.py
Normal file
@@ -0,0 +1,275 @@
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
||||
CONFIG_MAP = {
|
||||
"toy": transformers.LlamaConfig(num_hidden_layers=4),
|
||||
"llama-7b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=2048,
|
||||
),
|
||||
"llama-13b": transformers.LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_attention_heads=40,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=40,
|
||||
max_position_embeddings=2048,
|
||||
),
|
||||
"llama2-7b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=4096,
|
||||
),
|
||||
"llama2-13b": transformers.LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_attention_heads=40,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=40,
|
||||
max_position_embeddings=4096,
|
||||
),
|
||||
"llama3-8b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=8,
|
||||
max_position_embeddings=8192,
|
||||
),
|
||||
"llama3-70b": transformers.LlamaConfig(
|
||||
hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_attention_heads=64,
|
||||
num_hidden_layers=80,
|
||||
num_key_value_heads=8,
|
||||
max_position_embeddings=8192,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
|
||||
return input_ids
|
||||
|
||||
|
||||
def print_details_info(model_config, args, whole_end2end, total_token_num):
|
||||
msg: str = ""
|
||||
|
||||
if dist.get_rank() == 0:
|
||||
msg += "-------Perf Summary-------\n"
|
||||
whole_avg_latency = whole_end2end / (total_token_num)
|
||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12
|
||||
if args.dtype in ["fp16", "bf16"]:
|
||||
num_bytes = 2
|
||||
else:
|
||||
num_bytes = 4
|
||||
|
||||
msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
|
||||
msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
|
||||
msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n"
|
||||
msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
|
||||
|
||||
if torch.cuda.is_available():
|
||||
msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
|
||||
msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
|
||||
msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"
|
||||
|
||||
print(msg)
|
||||
|
||||
|
||||
def benchmark_inference(args):
|
||||
with torch.no_grad():
|
||||
config = CONFIG_MAP[args.model]
|
||||
config.pad_token_id = config.eos_token_id
|
||||
|
||||
if args.mode != "vllm":
|
||||
if args.test_random_weight:
|
||||
model = transformers.LlamaForCausalLM(config).cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
else:
|
||||
assert args.model_path, "When testing pretrained weights, the model path must be provided.'"
|
||||
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path).cuda()
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
|
||||
model = model.eval()
|
||||
|
||||
if args.dtype == "fp16":
|
||||
model = model.half()
|
||||
elif args.dtype == "bf16":
|
||||
model = model.to(torch.bfloat16)
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
max_length=args.seq_len + args.output_len,
|
||||
# max_new_tokens=args.max_output_len,
|
||||
)
|
||||
|
||||
if args.continous_batching:
|
||||
mbsz = args.mbsz
|
||||
else:
|
||||
mbsz = args.batch_size
|
||||
if args.mode == "colossalai":
|
||||
inference_config = InferenceConfig(
|
||||
dtype=args.dtype,
|
||||
max_batch_size=mbsz,
|
||||
max_input_len=args.seq_len,
|
||||
max_output_len=args.output_len,
|
||||
prefill_ratio=1.2,
|
||||
block_size=32,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=True,
|
||||
)
|
||||
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
elif args.mode == "vllm":
|
||||
engine = LLM(
|
||||
model=args.model_path,
|
||||
tokenizer="hf-internal-testing/llama-tokenizer",
|
||||
max_num_seqs=mbsz,
|
||||
dtype="float16",
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=args.output_len,
|
||||
)
|
||||
else:
|
||||
engine = model
|
||||
|
||||
data = data_gen(mbsz, args.seq_len)
|
||||
|
||||
if args.mode == "colossalai" or args.mode == "vllm":
|
||||
data = data.tolist()
|
||||
|
||||
N_WARMUP_STEPS = 2
|
||||
|
||||
ctx = (
|
||||
torch.profiler.profile(
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(f"./tb_log_{args.batch_size}_" + args.mode),
|
||||
)
|
||||
if args.profile
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with ctx:
|
||||
for _ in range(N_WARMUP_STEPS):
|
||||
if args.mode == "colossalai":
|
||||
engine.generate(prompts_token_ids=data, generation_config=generation_config)
|
||||
elif args.mode == "vllm":
|
||||
engine.generate(prompt_token_ids=data, sampling_params=sampling_params)
|
||||
else:
|
||||
engine.generate(data, generation_config=generation_config)
|
||||
if args.profile:
|
||||
ctx.step()
|
||||
|
||||
if args.nsys:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
whole_end2end = time.perf_counter()
|
||||
|
||||
if args.mode == "colossalai":
|
||||
for _ in range(args.batch_size // mbsz):
|
||||
output, output_tokens_list = engine.generate(
|
||||
prompts_token_ids=data, generation_config=generation_config, return_token_ids=True
|
||||
)
|
||||
elif args.mode == "vllm":
|
||||
for _ in range(args.batch_size // mbsz):
|
||||
output = engine.generate(prompt_token_ids=data, sampling_params=sampling_params)
|
||||
else:
|
||||
for _ in range(args.batch_size // mbsz):
|
||||
output = engine.generate(data, generation_config=generation_config)
|
||||
|
||||
whole_end2end = time.perf_counter() - whole_end2end
|
||||
|
||||
if args.mode == "colossalai":
|
||||
total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list])
|
||||
elif args.mode == "vllm":
|
||||
total_token_num = sum([len(out.outputs[0].token_ids) for out in output])
|
||||
else:
|
||||
total_token_num = sum([len(out) for out in output])
|
||||
|
||||
print("total_token_num: ", total_token_num)
|
||||
if args.nsys:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
if args.profile:
|
||||
ctx.step()
|
||||
print(f"config:batch_size {args.batch_size}, input_len{ args.seq_len}, output_len {args.output_len}")
|
||||
print_details_info(config, args, whole_end2end, total_token_num)
|
||||
|
||||
|
||||
def hybrid_inference(rank, world_size, port, args):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
benchmark_inference(args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def benchmark(args):
|
||||
spawn(hybrid_inference, nprocs=args.tp_size, args=args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
default="toy",
|
||||
help="the size of model",
|
||||
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"],
|
||||
)
|
||||
parser.add_argument("--model_path", type=str, default=None, help="The pretrained weights path")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
||||
parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
|
||||
parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
|
||||
parser.add_argument("--tp_size", type=int, default=1, help="Tensor Parallelism size")
|
||||
parser.add_argument("--output_len", type=int, default=128, help="Output length")
|
||||
parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
|
||||
parser.add_argument(
|
||||
"--test_random_weight", default=False, action="store_true", help="whether to test random weight"
|
||||
)
|
||||
parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
|
||||
parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
default="colossalai",
|
||||
choices=["colossalai", "transformers", "vllm"],
|
||||
help="decide which inference framework to run",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-cb", "--continous_batching", default=False, action="store_true", help="enable continous batching"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
benchmark(args)
|
216
examples/inference/llama/benchmark_llama3.py
Normal file
216
examples/inference/llama/benchmark_llama3.py
Normal file
@@ -0,0 +1,216 @@
|
||||
import argparse
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
|
||||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024**2
|
||||
N_WARMUP_STEPS = 2
|
||||
|
||||
CONFIG_MAP = {
|
||||
"toy": transformers.LlamaConfig(num_hidden_layers=4),
|
||||
"llama-7b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=2048,
|
||||
),
|
||||
"llama-13b": transformers.LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_attention_heads=40,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=40,
|
||||
max_position_embeddings=2048,
|
||||
),
|
||||
"llama2-7b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=32,
|
||||
max_position_embeddings=4096,
|
||||
),
|
||||
"llama2-13b": transformers.LlamaConfig(
|
||||
hidden_size=5120,
|
||||
intermediate_size=13824,
|
||||
num_attention_heads=40,
|
||||
num_hidden_layers=40,
|
||||
num_key_value_heads=40,
|
||||
max_position_embeddings=4096,
|
||||
),
|
||||
"llama3-8b": transformers.LlamaConfig(
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_attention_heads=32,
|
||||
num_hidden_layers=32,
|
||||
num_key_value_heads=8,
|
||||
max_position_embeddings=8192,
|
||||
),
|
||||
"llama3-70b": transformers.LlamaConfig(
|
||||
hidden_size=8192,
|
||||
intermediate_size=28672,
|
||||
num_attention_heads=64,
|
||||
num_hidden_layers=80,
|
||||
num_key_value_heads=8,
|
||||
max_position_embeddings=8192,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_accelerator().get_current_device())
|
||||
return input_ids.tolist()
|
||||
|
||||
|
||||
def print_details_info(model_config, whole_end2end, total_token_num, dtype, coordinator=None):
|
||||
if coordinator is None:
|
||||
coordinator = DistCoordinator()
|
||||
msg = "-------Perf Summary-------\n"
|
||||
whole_avg_latency = whole_end2end / (total_token_num)
|
||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12
|
||||
if dtype in ["fp16", "bf16"]:
|
||||
num_bytes = 2
|
||||
elif dtype == "fp32":
|
||||
num_bytes = 4
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype {dtype}")
|
||||
|
||||
msg += f"Whole batch end2end time: {whole_end2end * 1000:.2f} ms\n"
|
||||
msg += f"Whole batch per token latency: {whole_avg_latency * 1000:.2f} ms\n"
|
||||
msg += f"Throughput: {total_token_num / whole_end2end:.2f} tokens/s\n"
|
||||
msg += f"Flops: {num_parameters * num_bytes / whole_avg_latency / 1e12:.2f} TFLOPS\n"
|
||||
if torch.cuda.is_available():
|
||||
msg += f"-------Memory Summary Device:{get_accelerator().current_device()}-------\n"
|
||||
msg += f"Max memory allocated: {get_accelerator().max_memory_allocated() / GIGABYTE:.2f} GB\n"
|
||||
msg += f"Max memory reserved: {get_accelerator().max_memory_reserved() / GIGABYTE:.2f} GB\n"
|
||||
|
||||
coordinator.print_on_master(msg)
|
||||
|
||||
|
||||
def benchmark_inference(args):
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
config = CONFIG_MAP[args.model]
|
||||
config.pad_token_id = config.eos_token_id
|
||||
if args.model_path is not None:
|
||||
model = transformers.LlamaForCausalLM.from_pretrained(args.model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
else:
|
||||
# Random weights
|
||||
model = transformers.LlamaForCausalLM(config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||
if args.dtype == "fp16":
|
||||
model = model.half()
|
||||
elif args.dtype == "bf16":
|
||||
model = model.to(torch.bfloat16)
|
||||
|
||||
inference_config = InferenceConfig(
|
||||
dtype=args.dtype,
|
||||
max_batch_size=args.batch_size,
|
||||
max_input_len=args.max_seq_len,
|
||||
max_output_len=args.max_output_len,
|
||||
prefill_ratio=1.2,
|
||||
block_size=32,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=True,
|
||||
)
|
||||
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
|
||||
|
||||
data = data_gen(args.batch_size, args.max_seq_len)
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
max_length=args.max_seq_len + args.max_output_len,
|
||||
# max_new_tokens=args.max_output_len,
|
||||
)
|
||||
coordinator.print_on_master(f"Generation Config: \n{generation_config.to_dict()}")
|
||||
|
||||
ctx = (
|
||||
torch.profiler.profile(
|
||||
record_shapes=True,
|
||||
with_stack=True,
|
||||
with_modules=True,
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
f"./tb_log_{args.batch_size}_{args.max_seq_len}_{args.max_output_len}"
|
||||
),
|
||||
)
|
||||
if args.profile
|
||||
else nullcontext()
|
||||
)
|
||||
with ctx:
|
||||
for _ in range(N_WARMUP_STEPS):
|
||||
engine.generate(prompts_token_ids=data, generation_config=generation_config)
|
||||
if args.profile:
|
||||
ctx.step()
|
||||
if args.nsys:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
|
||||
torch.cuda.synchronize()
|
||||
whole_end2end = time.perf_counter()
|
||||
output, output_tokens_list = engine.generate(
|
||||
prompts_token_ids=data, generation_config=generation_config, return_token_ids=True
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
whole_end2end = time.perf_counter() - whole_end2end
|
||||
|
||||
total_token_num = sum([len(output_tokens) for output_tokens in output_tokens_list])
|
||||
coordinator.print_on_master(f"total_token_num: {total_token_num}")
|
||||
if args.nsys:
|
||||
torch.cuda.cudart().cudaProfilerStop()
|
||||
if args.profile:
|
||||
ctx.step()
|
||||
|
||||
print_details_info(model.config, whole_end2end, total_token_num, args.dtype, coordinator=coordinator)
|
||||
|
||||
|
||||
def inference(rank, world_size, port, args):
|
||||
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
benchmark_inference(args)
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def benchmark(args):
|
||||
spawn(inference, nprocs=args.tp_size, args=args)
|
||||
|
||||
|
||||
# python benchmark_llama3.py -m llama3-8b -b 16 -s 256 -o 256
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model",
|
||||
default="llama3-8b",
|
||||
help="The version of Llama model",
|
||||
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b", "llama3-8b", "llama3-70b"],
|
||||
)
|
||||
parser.add_argument("-p", "--model_path", type=str, default=None, help="The pretrained weights path")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
||||
parser.add_argument("-s", "--max_seq_len", type=int, default=8, help="input sequence length")
|
||||
parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Output length")
|
||||
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
|
||||
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
|
||||
parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
|
||||
parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark(args)
|
81
examples/inference/llama/llama_generation.py
Normal file
81
examples/inference/llama/llama_generation.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import argparse
|
||||
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.inference.config import InferenceConfig
|
||||
from colossalai.inference.core.engine import InferenceEngine
|
||||
from colossalai.inference.modeling.policy.nopadding_llama import NoPaddingLlamaModelInferPolicy
|
||||
|
||||
# For Llama 3, we'll use the following configuration
|
||||
MODEL_CLS = AutoModelForCausalLM
|
||||
POLICY_CLS = NoPaddingLlamaModelInferPolicy
|
||||
|
||||
|
||||
def infer(args):
|
||||
# ==============================
|
||||
# Launch colossalai, setup distributed environment
|
||||
# ==============================
|
||||
colossalai.launch_from_torch()
|
||||
coordinator = DistCoordinator()
|
||||
|
||||
# ==============================
|
||||
# Load model and tokenizer
|
||||
# ==============================
|
||||
model_path_or_name = args.model
|
||||
model = MODEL_CLS.from_pretrained(model_path_or_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path_or_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
coordinator.print_on_master(f"Model Config:\n{model.config}")
|
||||
|
||||
# ==============================
|
||||
# Initialize InferenceEngine
|
||||
# ==============================
|
||||
inference_config = InferenceConfig(
|
||||
dtype=args.dtype,
|
||||
max_batch_size=args.max_batch_size,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
prefill_ratio=1.2,
|
||||
block_size=16,
|
||||
tp_size=args.tp_size,
|
||||
use_cuda_kernel=args.use_cuda_kernel,
|
||||
)
|
||||
coordinator.print_on_master(f"Initializing Inference Engine...")
|
||||
engine = InferenceEngine(model, tokenizer, inference_config, model_policy=POLICY_CLS(), verbose=True)
|
||||
|
||||
# ==============================
|
||||
# Generation
|
||||
# ==============================
|
||||
generation_config = GenerationConfig(
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
max_length=args.max_length,
|
||||
do_sample=True,
|
||||
)
|
||||
coordinator.print_on_master(f"Generating...")
|
||||
out = engine.generate(prompts=[args.prompt], generation_config=generation_config)
|
||||
coordinator.print_on_master(out[0])
|
||||
|
||||
|
||||
# colossalai run --nproc_per_node 1 llama_generation.py -m MODEL_PATH
|
||||
if __name__ == "__main__":
|
||||
# ==============================
|
||||
# Parse Arguments
|
||||
# ==============================
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-m", "--model", type=str, help="Path to the model or model name")
|
||||
parser.add_argument(
|
||||
"-p", "--prompt", type=str, default="Introduce some landmarks in the United Kingdom, such as", help="Prompt"
|
||||
)
|
||||
parser.add_argument("-b", "--max_batch_size", type=int, default=1, help="Max batch size")
|
||||
parser.add_argument("-i", "--max_input_len", type=int, default=128, help="Max input length")
|
||||
parser.add_argument("-o", "--max_output_len", type=int, default=128, help="Max output length")
|
||||
parser.add_argument("-t", "--tp_size", type=int, default=1, help="Tensor Parallelism size")
|
||||
parser.add_argument("-d", "--dtype", type=str, default="fp16", help="Data type", choices=["fp16", "fp32", "bf16"])
|
||||
parser.add_argument("--use_cuda_kernel", action="store_true", help="Use CUDA kernel, use Triton by default")
|
||||
parser.add_argument("--max_length", type=int, default=32, help="Max length for generation")
|
||||
args = parser.parse_args()
|
||||
|
||||
infer(args)
|
33
examples/inference/llama/run_benchmark.sh
Executable file
33
examples/inference/llama/run_benchmark.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
ROOT=$(realpath $(dirname $0))
|
||||
echo $ROOT
|
||||
PY_SCRIPT=${ROOT}/benchmark_llama.py
|
||||
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
|
||||
mode=$1
|
||||
|
||||
mkdir -p logs
|
||||
|
||||
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
|
||||
|
||||
# benchmark llama2-7b one single GPU
|
||||
for input_len in 128 512 1024; do
|
||||
for output_len in 128 256; do
|
||||
for bsz in 16 32 64; do
|
||||
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 -b ${bsz} -s ${input_len} --output_len ${output_len} --mode ${mode} --test_random_weight | tee logs/${bsz}_${input_len}_${output_len}_${mode}_${GPU}.txt
|
||||
done
|
||||
done
|
||||
done
|
4
examples/inference/llama/test_ci.sh
Normal file
4
examples/inference/llama/test_ci.sh
Normal file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
echo "Skip the test (this test is slow)"
|
||||
|
||||
# bash ./run_benchmark.sh
|
Reference in New Issue
Block a user