diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 84810a82c..c62094f9c 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -236,6 +236,7 @@ class InferenceEngine: output_list = [] batch = self.request_handler.schedule() + # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. logits = self.model( batch, self.k_cahce, diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py index 55e1d7aef..99d6b3b85 100644 --- a/colossalai/inference/core/request_handler.py +++ b/colossalai/inference/core/request_handler.py @@ -57,9 +57,6 @@ class RunningList: def is_empty(self): return not self.decoding and not self.prefill - def total_seq_num(self): - return len(self.decoding) + len(self.prefill) - class RequestHandler: """ @@ -81,6 +78,7 @@ class RequestHandler: device = torch.cuda.current_device() self.running_batch = BatchInfo(is_prompts=False, device=device) self.prefill_batch = BatchInfo(is_prompts=True, device=device) + self.max_batch_size = inference_config.max_batch_size def _init_cache(self, model_config): self.cache_manager = KVCacheManager(self.inference_config, model_config) @@ -108,20 +106,18 @@ class RequestHandler: ) self.abort_sequence(seq.request_id) break - - # stop feeding new sequence into running list to assure - if self.cache_manager.num_available_blocks <= self.running_list.total_seq_num: - break - # Try to allocate cache blocks for the sequence. - if self.cache_manager.check_allocation(seq): + if ( + self.cache_manager.check_allocation(seq) + and (len(self.running_list.prefill) + len(self.running_list.decoding)) + < self.max_batch_size # There some bugs in continous batching, so we disable it here. + ): # If succeed, add the sequence to running list. remove_list.append(seq) self.running_list.append(seq) self.cache_manager.allocate_context_from_block_table(seq.block_table, seq.input_len) for seq in remove_list: lst.remove(seq) - if self.running_list.ready_for_prefill(): for seq in self.running_list.prefill: seq.mark_running() @@ -130,12 +126,7 @@ class RequestHandler: if not self.running_batch.is_empty: for seq in self.running_batch.sequences_set: - recycle = self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len) - if recycle: - seq.recycle() - self.running_batch.remove(seq) - self.waiting_list[-1].append(seq) - # the recycled sequences are handled with highest priority. + self.cache_manager.allocate_token_from_block_table(seq.block_table, seq.sentence_len) return self.running_batch diff --git a/colossalai/inference/modeling/layers/attention.py b/colossalai/inference/modeling/layers/attention.py index e1bd935e9..41e50f40d 100644 --- a/colossalai/inference/modeling/layers/attention.py +++ b/colossalai/inference/modeling/layers/attention.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from transformers.modeling_attn_mask_utils import AttentionMaskConverter +@torch.no_grad def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): """ Func: copy key/value into key/value cache. 
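Note on the scheduling change above: the old `total_seq_num` guard is replaced by a per-step admission check, where a waiting sequence joins the running list only if `check_allocation` succeeds and the running list stays below `max_batch_size`. A minimal stand-alone sketch of that rule (illustrative names only, not the actual `RequestHandler` API):

```python
# Hedged sketch of the admission rule introduced above; `can_fit`, `waiting`,
# and `running` are illustrative stand-ins, not ColossalAI objects.
def try_admit(waiting: list, running: list, can_fit, max_batch_size: int) -> list:
    admitted = []
    for seq in waiting:
        # Admit only if the KV cache can hold the prompt and the batch cap is not hit.
        if can_fit(seq) and len(running) + len(admitted) < max_batch_size:
            admitted.append(seq)
    for seq in admitted:
        waiting.remove(seq)
        running.append(seq)
    return running
```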
@@ -40,6 +41,7 @@ def copy_to_cache(source, cache, lengths, block_tables, type: str = "prefill"): return cache +@torch.no_grad def convert_kvcache(cache, lengths, block_tables, pad_id=0): """ Func: convert key/value cache for calculation @@ -79,6 +81,7 @@ class PagedAttention: """ @staticmethod + @torch.no_grad def pad_and_reshape(tensor, seq_lengths, max_seq_len, num_heads, head_size): """ Transform 1D no_pad tensor into 2D padded tensor with shape [bsz,seq_len,num_heads,head_size] @@ -94,12 +97,14 @@ class PagedAttention: return padded_tensor @staticmethod + @torch.no_grad def generate_padding_mask(lengths, max_seq_len): range_tensor = torch.arange(max_seq_len).expand(len(lengths), max_seq_len) padding_mask = range_tensor < lengths.unsqueeze(1) return padding_mask @staticmethod + @torch.no_grad def repeat_kv(hidden_states: torch.Tensor, n_rep: int = 1) -> torch.Tensor: """ Essential component for MQA. Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). @@ -117,6 +122,7 @@ class PagedAttention: return hidden_states.reshape(batch, num_attention_heads, seq_len, head_dim) @staticmethod + @torch.no_grad def nopad_context_forward( q: torch.Tensor, # [num_tokens, num_heads, head_size] k: torch.Tensor, # [num_tokens, num_kv_heads, head_size] @@ -185,6 +191,7 @@ class PagedAttention: return attn_output @staticmethod + @torch.no_grad def pad_context_forward( q: torch.Tensor, # [batch_size, seq_len, num_heads, head_size] k: torch.Tensor, # [batch_size, seq_len, num_kv_heads, head_size] @@ -239,11 +246,10 @@ class PagedAttention: attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1) - del attn_weights - return attn_output @staticmethod + @torch.no_grad def pad_decoding_forward( q: torch.Tensor, # [bsz, 1, num_heads, head_size] k: torch.Tensor, # [bsz, 1, num_kv_heads, head_size] @@ -297,11 +303,10 @@ class PagedAttention: raise ValueError(f"Got wrong attn_output, should be in shape {(bsz,num_heads,1,head_size)}.") attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, 1, -1) - del attn_weights - return attn_output @staticmethod + @torch.no_grad def no_pad_decoding_forward( self, q: torch.Tensor, # [num_tokens, num_heads, head_size] diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py index d41267138..bbdb2f407 100644 --- a/colossalai/inference/modeling/models/llama.py +++ b/colossalai/inference/modeling/models/llama.py @@ -2,19 +2,23 @@ from typing import List, Optional, Tuple import torch -from transformers.models.llama.modeling_llama import ( - LlamaAttention, - LlamaDecoderLayer, - LlamaForCausalLM, - LlamaModel, - repeat_kv, -) +from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel from colossalai.inference.modeling.layers.attention import PagedAttention from colossalai.inference.struct import BatchInfo +from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd +from colossalai.logging import get_dist_logger from flash_attn.bert_padding import index_first_axis, pad_input # noqa +logger = get_dist_logger(__name__) + +try: + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -35,6 +39,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): return q_embed, k_embed +@torch.no_grad() 
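Note on the `HAS_TRITON` flag set in the try/except above: for reference, a guarded optional import usually places the kernel import itself inside the `try`, so that a missing `triton` install is what actually raises the caught `ImportError`. Whether these `colossalai.kernel.triton` names should be imported lazily here is an assumption on my part:

```python
import logging

logger = logging.getLogger(__name__)

# Reference-only sketch of a guarded optional import; falls back to the
# PyTorch attention path when the Triton kernels are unavailable.
try:
    from colossalai.kernel.triton import (  # noqa: F401
        context_attention_unpadded,
        copy_kv_to_blocked_cache,
        flash_decoding_fwd,
    )

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    logger.warning("triton has not been installed; using the PyTorch attention implementation instead.")
```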
def llama_causal_lm_forward( self: LlamaForCausalLM, batch: BatchInfo = None, @@ -54,6 +59,7 @@ def llama_causal_lm_forward( return logits +@torch.no_grad() def llama_model_forward( self: LlamaModel, batch: BatchInfo = None, @@ -63,15 +69,30 @@ def llama_model_forward( ): input_ids = batch.get_batch_inputs() block_tables = batch.get_block_table_tensor() - sequence_lengths = batch.get_sequence_lengths() - attention_mask = batch.get_attn_mask(padding_id) - if batch.is_prompts: - # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer. - position_ids = generate_padding_position_id(attention_mask) + if attention_mask is not None: + # TODO After the nopad version is implemented, we will use the following code to get sequence_lengths. + # sequence_lengths = batch.get_sequence_lengths() + sequence_lengths = attention_mask.sum(dim=-1, dtype=torch.int32) else: - position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1) + sequence_lengths = batch.get_sequence_lengths() + + kv_seq_len = sequence_lengths.max().item() + + if attention_mask is not None: + if batch.is_prompts: + # Here, we generate position_ids through the input tensor, which can align with the output precision of the transformer. + position_ids = generate_padding_position_id(attention_mask) + else: + position_ids = (attention_mask.sum(dim=-1) - 1).reshape(-1, 1) + else: + if batch.is_prompts: + position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=batch.device) + position_ids = position_ids.unsqueeze(0) + else: + position_ids = torch.arange(kv_seq_len - 1, kv_seq_len, dtype=torch.long, device=batch.device) + position_ids = position_ids.unsqueeze(0) hidden_states = self.embed_tokens(input_ids) @@ -85,13 +106,14 @@ def llama_model_forward( is_prompts=batch.is_prompts, sequence_lengths=sequence_lengths, attention_mask=attention_mask, + kv_seq_len=kv_seq_len, ) hidden_states = self.norm(hidden_states) - return hidden_states +@torch.no_grad() def llama_decoder_layer_forward( self: LlamaDecoderLayer, hidden_states: torch.Tensor, @@ -102,6 +124,7 @@ def llama_decoder_layer_forward( is_prompts: bool = True, sequence_lengths: int = None, attention_mask: torch.Tensor = None, + kv_seq_len: int = 0, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: residual = hidden_states @@ -116,6 +139,7 @@ def llama_decoder_layer_forward( is_prompts=is_prompts, sequence_lengths=sequence_lengths, attention_mask=attention_mask, + kv_seq_len=kv_seq_len, ) hidden_states = residual + hidden_states @@ -130,6 +154,7 @@ def llama_decoder_layer_forward( # Replace transformers.models.llama.modeling_llama.LlamaAttention.forward +@torch.no_grad() def llama_attn_forward( self: LlamaAttention, hidden_states: torch.Tensor, @@ -140,6 +165,7 @@ def llama_attn_forward( is_prompts: bool = True, sequence_lengths: torch.Tensor = None, attention_mask: torch.Tensor = None, + kv_seq_len: int = 0, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -147,26 +173,44 @@ def llama_attn_forward( key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = sequence_lengths[0].item() - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, 
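Note on the hunk above: `llama_model_forward` now derives `position_ids` differently for padded prefill and single-token decoding. The helper below condenses that logic into a stand-alone function (a sketch for clarity, not the model's own code path):

```python
import torch

def positions_from_mask(attention_mask: torch.Tensor, is_prompts: bool) -> torch.Tensor:
    """Positions for a padded batch: cumulative count of real tokens per sequence."""
    if is_prompts:
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)  # padding pinned to a dummy position
        return position_ids                                 # [bsz, seq_len]
    # One decoding step: the next position is (number of real tokens - 1).
    return (attention_mask.sum(dim=-1) - 1).reshape(-1, 1)  # [bsz, 1]

mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
print(positions_from_mask(mask, is_prompts=True))
# tensor([[1, 1, 0, 1],
#         [0, 1, 2, 3]])
```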
position_ids) - key_states = repeat_kv(key_states, self.num_key_value_groups) - value_states = repeat_kv(value_states, self.num_key_value_groups) - query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) + _, _, _, block_size = k_cache.shape + if is_prompts: - attn_output = PagedAttention.pad_context_forward( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask - ) + if HAS_TRITON: + if attention_mask is not None: + query_states, key_states, value_states, indices = unpading_input( + query_states, key_states, value_states, attention_mask + ) + else: + query_states = query_states.view(-1, self.num_heads, self.head_dim) + key_states = key_states.view(-1, self.num_heads, self.head_dim) + value_states = value_states.view(-1, self.num_heads, self.head_dim) + + attn_output = context_attention_unpadded( + query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size + ) + if attention_mask is not None: + attn_output = pad_input(attn_output, indices, bsz, q_len) + else: + attn_output = PagedAttention.pad_context_forward( + query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask + ) else: - attn_output = PagedAttention.pad_decoding_forward( - query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask - ) + if HAS_TRITON: + copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables) + attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size) + else: + attn_output = PagedAttention.pad_decoding_forward( + query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask + ) attn_output = attn_output.view(bsz, q_len, self.num_heads, self.head_dim) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -175,7 +219,18 @@ def llama_attn_forward( return attn_output +@torch.no_grad() def generate_padding_position_id(attention_mask: torch.Tensor) -> torch.Tensor: position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) return position_ids + + +@torch.no_grad() +def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor): + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + batch_size, kv_seq_len, num_key_value_heads, head_dim = q.shape + q = index_first_axis(q.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + k = index_first_axis(k.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + v = index_first_axis(v.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices) + return (q, k, v, indices) diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py index c6552c339..54560d046 100644 --- a/colossalai/inference/struct.py +++ b/colossalai/inference/struct.py @@ -332,12 +332,20 @@ class BatchInfo: return torch.tensor(len_list, dtype=torch.int, device=self.device) def get_attn_mask(self, padding_id: int) -> torch.Tensor: + """ + Generate and return attention mask. 
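Note on the new `unpading_input` helper above: it flattens `[bsz, seq_len, heads, head_dim]` tensors and keeps only positions where the attention mask is 1, so `context_attention_unpadded` sees a packed token stream. The sketch below reproduces that behaviour with plain `torch.index_select` instead of flash_attn's `index_first_axis` (assumed equivalent for this gather):

```python
import torch

def unpad_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_mask: torch.Tensor):
    """Drop padded positions; return packed q/k/v plus the kept flat indices."""
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    bsz, seq_len, heads, head_dim = q.shape

    def gather(x: torch.Tensor) -> torch.Tensor:
        return x.reshape(bsz * seq_len, heads, head_dim).index_select(0, indices)

    return gather(q), gather(k), gather(v), indices
```

The returned `indices` are what `pad_input` later uses to scatter the packed attention output back to the `[bsz, q_len, ...]` layout.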
+ """ past_values = [] for seq in self.sequences_set: past_values.append(seq.input_token_id + seq.output_token_id) - return torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long() + attn_mask = torch.tensor(past_values, dtype=torch.int, device=self.device).ne(padding_id).long() + + if torch.any(attn_mask == 0): + return attn_mask + else: + return None def __repr__(self) -> str: return f"(sequences_set={self.sequences_set}, " f"is_prompts={self.is_prompts})" diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index 9a26098b3..2b3733c61 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -1,13 +1,16 @@ import argparse import time +from contextlib import nullcontext import torch import torch.distributed as dist import transformers +from transformers import AutoTokenizer, GenerationConfig import colossalai import colossalai.utils.device as device_utils -from colossalai.inference import InferenceEngine +from colossalai.inference.config import InferenceConfig +from colossalai.inference.core.engine import InferenceEngine from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn from colossalai.utils.device import get_current_device @@ -53,36 +56,14 @@ CONFIG_MAP = { def data_gen(batch_size: int = 4, seq_len: int = 512): input_ids = torch.randint(10, 30000, (batch_size, seq_len), device=get_current_device()) - attention_mask = torch.ones_like(input_ids) - data = dict(input_ids=input_ids, attention_mask=attention_mask) - return data + return input_ids -def print_details_info(outputs, model_config, args, whole_end2end): +def print_details_info(model_config, args, whole_end2end): msg: str = "" if dist.get_rank() == 0: msg += "-------Perf Summary-------\n" - if args.verbose: - timestamps = outputs[1] - prefill = [] - encoder = [] - end2end = [] - for timestamp in timestamps: - prefill.append(timestamp[1] - timestamp[0]) - encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) - ) - end2end.append(timestamp[-1] - timestamp[0]) - - mb_avg_end2end = sum(end2end) / len(end2end) - mb_avg_latency = mb_avg_end2end / (args.output_len * args.mb_size) - - msg += f"Average prefill time: {sum(prefill) / len(prefill) * 1000:.2f} ms\n" - msg += f"Average encode time: {sum(encoder) / len(encoder) * 1000:.2f} ms\n" - msg += f"Average micro batch end2end time: {mb_avg_end2end * 1000:.2f} ms\n" - msg += f"Average micro batch per token latency: {mb_avg_latency * 1000:.2f} ms\n" - whole_avg_latency = whole_end2end / (args.output_len * args.batch_size) num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size @@ -105,35 +86,87 @@ def print_details_info(outputs, model_config, args, whole_end2end): def benchmark_inference(args): - config = CONFIG_MAP[args.model] - model = transformers.LlamaForCausalLM(config) - if dist.get_rank() == 0: - print("Model loaded") - engine = InferenceEngine( - pp_size=args.pp_size, - tp_size=args.tp_size, - dtype=args.dtype, - micro_batch_size=args.mb_size, - model=model, - verbose=args.verbose, - max_batch_size=args.batch_size, - max_input_len=args.seq_len, - max_output_len=args.output_len, - ) - data = data_gen(args.batch_size, args.seq_len) + with torch.no_grad(): + config = CONFIG_MAP[args.model] + config.pad_token_id = config.eos_token_id + model = 
transformers.LlamaForCausalLM(config).cuda() + model = model.eval() + tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/") - N_WARMUP_STEPS = 2 + if args.dtype == "fp16": + model = model.half() + elif args.dtype == "bf16": + model = model.to(torch.bfloat16) - for _ in range(N_WARMUP_STEPS): - engine.generate(data) + # mbsz = args.mbsz + mbsz = args.batch_size + if args.mode == "caiinference": + inference_config = InferenceConfig( + dtype=args.dtype, + micro_batch_size=args.mb_size, + max_batch_size=mbsz, + max_input_len=args.seq_len, + max_output_len=args.output_len, + prefill_ratio=1.2, + ) + engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) + else: + engine = model - torch.cuda.synchronize() - whole_end2end = time.time() - outputs = engine.generate(data) - torch.cuda.synchronize() - whole_end2end = time.time() - whole_end2end + data = data_gen(mbsz, args.seq_len) + generation_config = GenerationConfig( + pad_token_id=tokenizer.pad_token_id, + max_new_tokens=args.output_len, + ) - print_details_info(outputs, model.config, args, whole_end2end) + N_WARMUP_STEPS = 2 + + ctx = ( + torch.profiler.profile( + record_shapes=True, + with_stack=True, + with_modules=True, + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=0, warmup=N_WARMUP_STEPS, active=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler("./tb_log_" + args.mode), + ) + if args.profile + else nullcontext() + ) + + with ctx: + for _ in range(N_WARMUP_STEPS): + if args.mode == "caiinference": + engine.add_request(prompts_token_ids=data) + engine.generate(generation_config) + else: + engine.generate(data, generation_config=generation_config) + if args.profile: + ctx.step() + + if args.nsys: + torch.cuda.cudart().cudaProfilerStart() + + torch.cuda.synchronize() + + whole_end2end = time.perf_counter() + if args.mode == "caiinference": + for _ in range(args.batch_size // mbsz): + engine.add_request(prompts_token_ids=data) + engine.generate(generation_config) + else: + for _ in range(args.batch_size // mbsz): + engine.generate(data, generation_config=generation_config) + whole_end2end = time.perf_counter() - whole_end2end + if args.nsys: + torch.cuda.cudart().cudaProfilerStop() + if args.profile: + ctx.step() + + print_details_info(model.config, args, whole_end2end) def hybrid_inference(rank, world_size, port, args): @@ -157,12 +190,21 @@ if __name__ == "__main__": choices=["toy", "llama-7b", "llama-13b", "llama2-7b", "llama2-13b"], ) parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") - parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--mbsz", type=int, default=8, help="batch size for one step") + parser.add_argument("-s", "--seq_len", type=int, default=8, help="input sequence length") parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") parser.add_argument("--pp_size", type=int, default=1, help="pipeline size") parser.add_argument("--tp_size", type=int, default=1, help="pipeline size") parser.add_argument("--output_len", type=int, default=128, help="Output length") - parser.add_argument("--dtype", type=str, default="fp16", help="data type") + parser.add_argument("--dtype", type=str, default="fp16", help="data type", choices=["fp16", "fp32", "bf16"]) parser.add_argument("-v", "--verbose", default=False, action="store_true") + parser.add_argument("--profile", default=False, 
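Note on the benchmarking loop above: it wraps warmup in either `torch.profiler.profile` or a `nullcontext`, stepping the profiler only when `--profile` is set. A reduced version of that scaffold, with `run_step` standing in for the engine call:

```python
import torch
from contextlib import nullcontext

def run_with_optional_profiler(run_step, warmup_steps: int, profile: bool, trace_dir: str = "./tb_log"):
    ctx = (
        torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
            schedule=torch.profiler.schedule(wait=0, warmup=warmup_steps, active=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir),
        )
        if profile
        else nullcontext()
    )
    with ctx:
        # warmup_steps profiled warmup iterations plus one recorded "active" step
        for _ in range(warmup_steps + 1):
            run_step()
            if profile:
                ctx.step()
```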
action="store_true", help="enable torch profiler") + parser.add_argument("--nsys", default=False, action="store_true", help="enable nsys profiler") + parser.add_argument( + "--mode", + default="caiinference", + choices=["caiinference", "transformers"], + help="decide which inference framework to run", + ) args = parser.parse_args() benchmark(args) diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh index 394222ea6..294bba7da 100755 --- a/examples/inference/run_benchmark.sh +++ b/examples/inference/run_benchmark.sh @@ -1,15 +1,33 @@ ROOT=$(realpath $(dirname $0)) PY_SCRIPT=${ROOT}/benchmark_llama.py GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1) +mode=$1 mkdir -p logs +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() { + local n=${1:-"9999"} + echo "GPU Memory Usage:" + local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \ + | tail -n +2 \ + | nl -v 0 \ + | tee /dev/tty \ + | sort -g -k 2 \ + | awk '{print $1}' \ + | head -n $n) + export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g') + echo "Now CUDA_VISIBLE_DEVICES is set to:" + echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" +} + +CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1 + # benchmark llama2-7b one single GPU for bsz in 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 | tee logs/${GPU}_${bsz}_256.txt + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt done -for bsz in 4 8 16 32 64; do - python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 | tee logs/${GPU}_${bsz}_1024.txt +for bsz in 16 32 64; do + python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt done