[Inference]Adapted to the triton attn kernels (#5264)

* adapted to the triton attn kernels

* fix pad input

* adapted to copy_kv_to_blocked_cache

* fix ci test

* update kv memcpy

* remove print
This commit is contained in:
yuehuayingxueluo
2024-01-17 16:03:10 +08:00
committed by GitHub
parent 0f2b46a41c
commit 86b63f720c
7 changed files with 221 additions and 101 deletions

View File

@@ -1,13 +1,16 @@
import argparse
import time
from contextlib import nullcontext
import torch
import torch.distributed as dist
import transformers
from transformers import AutoTokenizer, GenerationConfig
import colossalai
import colossalai.utils.device as device_utils
from colossalai.inference import InferenceEngine
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from colossalai.utils.device import get_current_device
@@ -53,36 +56,14 @@ CONFIG_MAP = {
def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device())
attention_mask = torch.ones_like(input_ids)
data = dict(input_ids=input_ids, attention_mask=attention_mask)
return data
return input_ids
def print_details_info(outputs, model_config, args, whole_end2end):
def print_details_info(model_config, args, whole_end2end):
msg: str = ""
if dist.get_rank() == 0:
msg += "-------Perf Summary-------\n"
if args.verbose:
timestamps = outputs[1]
prefill = []
encoder = []
end2end = []
for timestamp in timestamps:
prefill.append(timestamp[1] - timestamp[0])
encoder.append(
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
)
end2end.append(timestamp[-1] - timestamp[0])
mb_avg_end2end = sum(end2end) / len(end2end)
mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size)
msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n"
msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n"
msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n"
msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n"
whole_avg_latency = whole_end2end / (args.output_len * args.batch_size)
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
@@ -105,35 +86,87 @@ def print_details_info(outputs, model_config, args, whole_end2end):
def benchmark_inference(args):
config = CONFIG_MAP[args.model]
model = transformers.LlamaForCausalLM(config)
if dist.get_rank() == 0:
print("Model loaded")
engine = InferenceEngine(
pp_size=args.pp_size,
tp_size=args.tp_size,
dtype=args.dtype,
micro_batch_size=args.mb_size,
model=model,
verbose=args.verbose,
max_batch_size=args.batch_size,
max_input_len=args.seq_len,
max_output_len=args.output_len,
)
data = data_gen(args.batch_size, args.seq_len)
with torch.no_grad():
config = CONFIG_MAP[args.model]
config.pad_token_id = config.eos_token_id
model = transformers.LlamaForCausalLM(config).cuda()
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
N_WARMUP_STEPS = 2
if args.dtype == "fp16":
model = model.half()
elif args.dtype == "bf16":
model = model.to(torch.bfloat16)
for _ in range(N_WARMUP_STEPS):
engine.generate(data)
# mbsz = args.mbsz
mbsz = args.batch_size
if args.mode == "caiinference":
inference_config = InferenceConfig(
dtype=args.dtype,
micro_batch_size=args.mb_size,
max_batch_size=mbsz,
max_input_len=args.seq_len,
max_output_len=args.output_len,
prefill_ratio=1.2,
)
engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
else:
engine = model
torch.cuda.synchronize()
whole_end2end = time.time()
outputs = engine.generate(data)
torch.cuda.synchronize()
whole_end2end = time.time() - whole_end2end
data = data_gen(mbsz, args.seq_len)
generation_config = GenerationConfig(
pad_token_id=tokenizer.pad_token_id,
max_new_tokens=args.output_len,
)
print_details_info(outputs, model.config, args, whole_end2end)
N_WARMUP_STEPS = 2
ctx = (
torch.profiler.profile(
record_shapes=True,
with_stack=True,
with_modules=True,
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1),
on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_" + args.mode),
)
if args.profile
else nullcontext()
)
with ctx:
for _ in range(N_WARMUP_STEPS):
if args.mode == "caiinference":
engine.add_request(prompts_token_ids=data)
engine.generate(generation_config)
else:
engine.generate(data, generation_config=generation_config)
if args.profile:
ctx.step()
if args.nsys:
torch.cuda.cudart().cudaProfilerStart()
torch.cuda.synchronize()
whole_end2end = time.perf_counter()
if args.mode == "caiinference":
for _ in range(args.batch_size // mbsz):
engine.add_request(prompts_token_ids=data)
engine.generate(generation_config)
else:
for _ in range(args.batch_size // mbsz):
engine.generate(data, generation_config=generation_config)
whole_end2end = time.perf_counter() - whole_end2end
if args.nsys:
torch.cuda.cudart().cudaProfilerStop()
if args.profile:
ctx.step()
print_details_info(model.config, args, whole_end2end)
def hybrid_inference(rank, world_size, port, args):
@@ -157,12 +190,21 @@ if __name__ == "__main__":
choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"],
)
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step")
parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length")
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
parser.add_argument("--pp_size", type=int, default=1, help="pipeline size")
parser.add_argument("--tp_size", type=int, default=1, help="pipeline size")
parser.add_argument("--output_len", type=int, default=128, help="Output length")
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"])
parser.add_argument("-v", "--verbose", default=False, action="store_true")
parser.add_argument("--profile", default=False, action="store_true", help="enable torch profiler")
parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler")
parser.add_argument(
"--mode",
default="caiinference",
choices=["caiinference", "transformers"],
help="decide which inference framework to run",
)
args = parser.parse_args()
benchmark(args)

View File

@@ -1,15 +1,33 @@
ROOT=$(realpath $(dirname $0))
PY_SCRIPT=${ROOT}/benchmark_llama.py
GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
mode=$1
mkdir -p logs
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
local n=${1:-"9999"}
echo "GPU Memory Usage:"
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
| tail -n +2 \
| nl -v 0 \
| tee /dev/tty \
| sort -g -k 2 \
| awk '{print $1}' \
| head -n $n)
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
echo "Now CUDA_VISIBLE_DEVICES is set to:"
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
# benchmark llama2-7b one single GPU
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
done
for bsz in 4 8 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt
for bsz in 16 32 64; do
python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
done