diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py
index 2c77a6e12..5014821d0 100644
--- a/colossalai/inference/config.py
+++ b/colossalai/inference/config.py
@@ -55,6 +55,7 @@ class InferenceConfig:
     def __post_init__(self):
         self._init_batch_size()
         self._verify_config()
+        self._get_dtype()
 
     def _init_batch_size(self):
         """
@@ -84,6 +85,7 @@ class InferenceConfig:
         assert (
             self.tp_size * self.pp_size == dist.get_world_size()
         ), f"TP size({self.tp_size}) * PP size({self.pp_size}) should be equal to the global world size ({dist.get_world_size()})"
+
         assert self.dtype in [
             "fp16",
             "fp32",
@@ -97,3 +99,11 @@ class InferenceConfig:
             "gptq",
             None,
         ], f"quant should be one of 'smoothquant', 'gptq', but got {self.quant_mode}."
+
+    def _get_dtype(self) -> None:
+        if self.dtype == "fp32" or self.dtype == torch.float32:
+            self.dtype = torch.float32
+        elif self.dtype == "fp16" or self.dtype == torch.float16:
+            self.dtype = torch.float16
+        else:
+            self.dtype = torch.bfloat16
diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index c62094f9c..9c49a60a0 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -51,17 +51,10 @@ class InferenceEngine:
         self.inference_config = inference_config
         self.model_config = model.config
         self.device = torch.device("cuda")
+        self.dtype = inference_config.dtype
 
         model = model.eval()
-
-        if inference_config.dtype == "fp32" or inference_config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif inference_config.dtype == "fp16" or inference_config.dtype == torch.float16:
-            self.dtype = torch.float16
-            model.half()
-        else:
-            self.dtype = torch.bfloat16
-            model.to(torch.bfloat16)
+        model.to(self.dtype)
 
         if model_policy is None:
             model_policy = model_policy_map[self.model_config.model_type]()
@@ -217,6 +210,7 @@ class InferenceEngine:
                 None,
                 block_table,
                 self.tokenizer.eos_token_id,
+                self.tokenizer.pad_token_id,
                 self.inference_config.max_output_len,
             )
             self.request_handler.add_sequence(sequence)
@@ -241,7 +235,6 @@ class InferenceEngine:
             batch,
             self.k_cahce,
             self.v_cache,
-            padding_id=self.tokenizer.pad_token_id,
         )
 
         logits = logits[:, -1, :]
diff --git a/colossalai/inference/core/request_handler.py b/colossalai/inference/core/request_handler.py
index 730a358cd..585f87945 100644
--- a/colossalai/inference/core/request_handler.py
+++ b/colossalai/inference/core/request_handler.py
@@ -4,6 +4,7 @@ import torch
 from transformers.configuration_utils import PretrainedConfig
 
 from colossalai.inference.config import InferenceConfig
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.kv_cache import KVCacheManager
 from colossalai.inference.logit_processors import logit_processor
 from colossalai.inference.sampler import *
@@ -69,20 +70,60 @@ class RequestHandler:
     Args:
         inference_config: Configuration for initialize and manage kv cache.
         model_config: Configuration for model
+        dtype (torch.dtype): The data type for weights and activations.
     """
 
     def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None:
         self.inference_config = inference_config
-        self._init_cache(model_config)
-
         self.running_list: RunningList = RunningList(inference_config.prefill_ratio)
         self.waiting_list: List[List] = [[], [], []]
         self.done_list: List[Sequence] = []
-        device = torch.cuda.current_device()
-        self.running_batch = BatchInfo(is_prompts=False, device=device)
-        self.prefill_batch = BatchInfo(is_prompts=True, device=device)
+        self.dtype = inference_config.dtype
         self.max_batch_size = inference_config.max_batch_size
 
+        # initialize cache
+        self._init_cache(model_config)
+
+        # initialize batch
+        device = torch.cuda.current_device()
+        kv_max_split_num = (
+            inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1
+        ) // inference_config.block_size
+        head_dim = model_config.hidden_size // model_config.num_attention_heads
+
+        fd_inter_tensor = FDIntermTensors()
+        fd_inter_tensor.initialize(
+            max_batch_size=self.max_batch_size,
+            num_attn_heads=model_config.num_attention_heads,
+            kv_max_split_num=kv_max_split_num,
+            head_dim=head_dim,
+            dtype=self.dtype,
+            device=device,
+        )
+
+        # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
+        # which may cause bugs and this issue should be fixed later.
+        self.running_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=False,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
+        self.prefill_batch = BatchInfo(
+            max_batch_size=self.max_batch_size,
+            kv_max_split_num=kv_max_split_num,
+            num_heads=model_config.num_attention_heads,
+            head_dim=head_dim,
+            is_prompts=True,
+            device=device,
+            dtype=self.dtype,
+            fd_inter_tensor=fd_inter_tensor,
+        )
+
     def _init_cache(self, model_config):
         self.cache_manager = KVCacheManager(self.inference_config, model_config)
diff --git a/colossalai/kernel/triton/flash_decoding_utils.py b/colossalai/inference/flash_decoding_utils.py
similarity index 100%
rename from colossalai/kernel/triton/flash_decoding_utils.py
rename to colossalai/inference/flash_decoding_utils.py
diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 3a1e31c8d..5bcc3e35f 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -58,12 +58,7 @@ class KVCacheManager:
         # Parallel settings
         self.tp_size = config.tp_size
         # Model settings
-        if config.dtype == "fp32" or config.dtype == torch.float32:
-            self.dtype = torch.float32
-        elif config.dtype == "fp16" or config.dtype == torch.float16:
-            self.dtype = torch.float16
-        else:
-            self.dtype = torch.bfloat16
+        self.dtype = config.dtype
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         # For now we focus on MHA only, TODO add handling for MQA and GQA
diff --git a/colossalai/inference/modeling/models/llama.py b/colossalai/inference/modeling/models/llama.py
index ffd7d2292..3e3890545 100644
--- a/colossalai/inference/modeling/models/llama.py
+++ b/colossalai/inference/modeling/models/llama.py
@@ -4,6 +4,7 @@ from typing import List, Optional, Tuple
 import torch
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
 from colossalai.kernel.triton import (
@@ -50,7 +51,6 @@ def llama_causal_lm_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
     hidden_states = llama_model_forward(
@@ -58,7 +58,6 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
-        padding_id=padding_id,
     )
     logits = self.lm_head(hidden_states)
     return logits
@@ -70,11 +69,10 @@ def llama_model_forward(
     batch: BatchInfo = None,
     k_caches: List[torch.Tensor] = None,
     v_caches: List[torch.Tensor] = None,
-    padding_id: int = None,
 ):
     input_ids = batch.get_batch_inputs()
     block_tables = batch.get_block_table_tensor()
-    attention_mask = batch.get_attn_mask(padding_id)
+    attention_mask = batch.get_attn_mask()
 
     if attention_mask is not None:
         if HAS_TRITON:
@@ -84,6 +82,7 @@ def llama_model_forward(
     else:
         sequence_lengths = batch.get_sequence_lengths()
 
+    batch_size, _ = input_ids.shape
     kv_seq_len = sequence_lengths.max().item()
 
     if attention_mask is not None:
@@ -102,7 +101,22 @@ def llama_model_forward(
 
     hidden_states = self.embed_tokens(input_ids)
 
-    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, hidden_states.dtype)
+    # When testing, the performance of get_xine_cache is lower than that of get_cos_sin.
+    # cos = get_xine_cache(sequence_lengths, self._cos_cached, batch.is_prompts)
+    # sin = get_xine_cache(sequence_lengths, self._sin_cached, batch.is_prompts)
+    # cos_sin = (cos, sin)
+
+    cos_sin = get_cos_sin(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts, batch.dtype)
+
+    if batch.is_prompts:
+        output_tensor = torch.zeros(
+            (sequence_lengths.sum().item(), batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    else:
+        output_tensor = torch.zeros(
+            (batch_size, 1, batch.num_heads, batch.head_dim), dtype=batch.dtype, device=batch.device
+        )
+    sm_scale = 1.0 / (batch.head_dim**0.5)
 
     for layer_id, decoder_layer in enumerate(self.layers):
         hidden_states = decoder_layer(
@@ -116,6 +130,9 @@ def llama_model_forward(
             attention_mask=attention_mask,
             kv_seq_len=kv_seq_len,
             cos_sin=cos_sin,
+            fd_inter_tensor=batch.fd_inter_tensor,
+            output_tensor=output_tensor,
+            sm_scale=sm_scale,
         )
 
     hidden_states = self.norm(hidden_states)
@@ -131,10 +148,13 @@ def llama_decoder_layer_forward(
     k_cache: torch.Tensor = None,
     v_cache: torch.Tensor = None,
     is_prompts: bool = True,
-    sequence_lengths: int = None,
+    sequence_lengths: torch.Tensor = None,
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
     residual = hidden_states
 
@@ -151,6 +171,9 @@ def llama_decoder_layer_forward(
         attention_mask=attention_mask,
         kv_seq_len=kv_seq_len,
         cos_sin=cos_sin,
+        fd_inter_tensor=fd_inter_tensor,
+        output_tensor=output_tensor,
+        sm_scale=sm_scale,
     )
 
     hidden_states = residual + hidden_states
@@ -178,6 +201,9 @@ def llama_attn_forward(
     attention_mask: torch.Tensor = None,
     kv_seq_len: int = 0,
     cos_sin: Tuple[torch.Tensor] = None,
+    fd_inter_tensor: FDIntermTensors = None,
+    output_tensor: torch.Tensor = None,
+    sm_scale: int = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
@@ -206,7 +232,17 @@ def llama_attn_forward(
 
     if is_prompts:
         attn_output = context_attention_unpadded(
-            query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            q=query_states,
+            k=key_states,
+            v=value_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            context_lengths=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            output=output_tensor,
+            max_seq_len=kv_seq_len,
+            sm_scale=sm_scale,
         )
         if attention_mask is not None:
             attn_output = pad_input(attn_output, indices, bsz, q_len)
@@ -214,7 +250,17 @@ def llama_attn_forward(
         copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
         attn_output = flash_decoding_attention(
-            query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            q=query_states,
+            k_cache=k_cache,
+            v_cache=v_cache,
+            kv_seq_len=sequence_lengths,
+            block_tables=block_tables,
+            block_size=block_size,
+            max_seq_len_in_batch=kv_seq_len,
+            output=output_tensor,
+            mid_output=fd_inter_tensor.mid_output,
+            mid_output_lse=fd_inter_tensor.mid_output_lse,
+            sm_scale=sm_scale,
         )
         attn_output = attn_output.squeeze(1)
     else:
@@ -285,6 +331,16 @@ def unpading_input(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attention_
 
 @torch.no_grad()
 def get_cos_sin(lengths, cos_cache, sin_cache, is_prompts, dtype):
+    """
+    Get cos and sin for the cache, and return nopad format.
+    Args:
+        lengths: shape(num_seqs,), stores length of each sequence.
+        cos_cache: shape(max_rotary_position(e.g.2048), head_dim), cos cache constructed in model.
+        sin_cache: shape(max_rotary_position(e.g.2048), head_dim), sin cache constructed in model.
+        is_prompts: bool, mark if in prefill mode.
+        dtype: The data type of this inference process.
+    """
+
     if is_prompts:
         index_arrays = [torch.arange(length) for length in lengths]
     else:
diff --git a/colossalai/inference/struct.py b/colossalai/inference/struct.py
index 05ab72bf4..feb50da99 100644
--- a/colossalai/inference/struct.py
+++ b/colossalai/inference/struct.py
@@ -5,6 +5,7 @@ from typing import Any, List, Tuple, Union
 import torch
 from ordered_set import OrderedSet
 
+from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger(__name__)
@@ -61,6 +62,7 @@ class Sequence:
         sample_params (SampleParams): The sample_params of input sequence.
         block_table (torch.Tensor): The index of input sequence in block_table.
         eos_token_id (int): The eos token id for this inference process.
+        pad_token_id (int): The pad token id for this inference process.
         max_output_len (int): Maximum output length.
     """
 
@@ -71,6 +73,7 @@ class Sequence:
     sample_params: Any  # SampleParams needs to be imported later.
     block_table: torch.Tensor
    eos_token_id: int
+    pad_token_id: int
     max_output_len: int = 256
 
     def __post_init__(self):
@@ -167,15 +170,23 @@ class BatchInfo:
     Information to be passed and used for a batch of sequences.
     """
 
+    max_batch_size: int
+    kv_max_split_num: int
+    num_heads: int
+    head_dim: int
     sequences_set: OrderedSet[Sequence] = None
     is_prompts: bool = True
     device: torch.device = None
+    dtype: torch.dtype = None
+    fd_inter_tensor: FDIntermTensors = None
 
     def __post_init__(self):
         if self.device is None:
             self.device = torch.cuda.current_device()
         if self.sequences_set is None:
             self.sequences_set = OrderedSet()
+        if self.fd_inter_tensor is None:
+            self.fd_inter_tensor = FDIntermTensors()
 
     def init_batch(self, seqs: List["Sequence"] = None):
         """
@@ -185,8 +196,6 @@ class BatchInfo:
             seqs (List["Sequence"]): List of input sequence.
         """
 
-        assert len(self.sequences_set) == 0, "Sequences set has been initialized."
-
         if seqs is not None:
             if not isinstance(seqs, list):
                 seqs = [seqs]
@@ -197,16 +206,30 @@ class BatchInfo:
                 self.sequences_set.add(seq)
 
+    def init_fd_tensors(self):
+        if not self.fd_inter_tensor.is_initialized:
+            self.fd_inter_tensor.initialize(
+                max_batch_size=self.max_batch_size,
+                num_attn_heads=self.num_heads,
+                kv_max_split_num=self.kv_max_split_num,
+                head_dim=self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+
     def get_block_table_tensor(self) -> None:
         tesnor_list = []
         block_table = None
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             block_table = seq.block_table
             assert (
                 block_table is not None
             ), f"The sequence(request_id {seq.request_id}) has not initialized the block_table."
             tesnor_list.append(seq.block_table)
-        assert tesnor_list, "Batch has not been initialized yet. Please initialize batch first."
+
         block_table = torch.stack(tesnor_list)
         return block_table
 
@@ -218,7 +241,6 @@ class BatchInfo:
         """
         if self.is_prompts:
             self.sequences_set.clear()
-
         else:
             for seq in self.sequences_set:
                 seq.mark_aborted()
@@ -312,14 +334,14 @@ class BatchInfo:
         """
         Get bacth inputs for forward inference computation.
         """
+
         input_list = []
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 if seq.output_len > 0:
-                    print(seq.output_token_id)
-                    seq_data = seq.input_token_id + seq.output_token_id
-                    print(seq_data)
                     input_list.append(seq.input_token_id + seq.output_token_id)
                 else:
                     input_list.append(seq.input_token_id)
@@ -328,7 +350,8 @@ class BatchInfo:
 
         max_seq_len = max(len(sub_list) for sub_list in input_list)
 
-        return _make_tensor_with_pad(input_list, max_seq_len, 0, dtype=torch.int)
+        # We assume that all the padding_id in seq are the same at present.
+        return _make_tensor_with_pad(input_list, max_seq_len, self.sequences_set[0].pad_token_id, dtype=torch.int)
 
     def get_1D_inputs(self) -> Tuple[torch.LongTensor, torch.Tensor]:
         """
@@ -336,6 +359,9 @@ class BatchInfo:
         """
         input_list = []
         input_len_list = []
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             if self.is_prompts:
                 input_list.extend(seq.input_token_id)
@@ -353,16 +379,23 @@ class BatchInfo:
         Get the input_len of each sentence in this batch.
         """
         len_list = []
+
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         for seq in self.sequences_set:
             len_list.append(seq.sentence_len)
 
         return torch.tensor(len_list, dtype=torch.int, device=self.device)
 
-    def get_attn_mask(self, padding_id: int) -> torch.Tensor:
+    def get_attn_mask(self) -> torch.Tensor:
         """
         Generate and return attention mask.
         """
+        assert len(self.sequences_set) > 0, "Batch has not been initialized yet. Please initialize batch first."
+
         past_values = []
+        # We assume that all the padding_id in seq are the same at present.
+        padding_id = self.sequences_set[0].pad_token_id
         for seq in self.sequences_set:
             past_values.append(seq.input_token_id + seq.output_token_id)
 
@@ -378,7 +411,7 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
-    return x + [pad] * (max_len - len(x))
+    return [pad] * (max_len - len(x)) + x
 
 
 def _make_tensor_with_pad(
diff --git a/colossalai/kernel/triton/__init__.py b/colossalai/kernel/triton/__init__.py
index fb8b3339b..8715f9981 100644
--- a/colossalai/kernel/triton/__init__.py
+++ b/colossalai/kernel/triton/__init__.py
@@ -10,7 +10,6 @@ except ImportError:
 if HAS_TRITON:
     from .context_attn_unpad import context_attention_unpadded
     from .flash_decoding import flash_decoding_attention
-    from .flash_decoding_utils import FDIntermTensors
     from .fused_rotary_embedding import fused_rotary_embedding
     from .gptq_triton import gptq_fused_linear_triton
     from .kvcache_copy import copy_kv_to_blocked_cache
@@ -27,7 +26,6 @@ if HAS_TRITON:
         "rms_layernorm",
         "gptq_fused_linear_triton",
         "rotary_embedding",
-        "FDIntermTensors",
         "fused_rotary_embedding",
         "get_xine_cache",
     ]
diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py
index e31d9e5da..3ef43cb83 100644
--- a/colossalai/kernel/triton/context_attn_unpad.py
+++ b/colossalai/kernel/triton/context_attn_unpad.py
@@ -5,7 +5,6 @@
 #
 # Inspired and modified from Triton Tutorial - Fused Attention
 # https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
-from typing import Optional
 
 import torch
 import triton
@@ -195,7 +194,9 @@ def context_attention_unpadded(
     context_lengths: torch.Tensor,  # [num_seqs]
     block_tables: torch.Tensor,  # [num_seqs, max_blocks_per_sequence],
     block_size: int,
-    max_seq_len_in_b: Optional[int] = None,
+    output: torch.Tensor = None,  # [num_tokens, num_heads, head_dim]
+    max_seq_len: int = None,
+    sm_scale: int = None,
 ):
     Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
     assert Lq == Lk == Lv
@@ -210,10 +211,9 @@ def context_attention_unpadded(
     num_kv_group = num_heads // num_kv_heads
 
     num_seqs, max_blocks_per_seq = block_tables.shape
-    max_seq_len = context_lengths.max().item() if max_seq_len_in_b is None else max_seq_len_in_b
-    sm_scale = 1.0 / (Lq**0.5)
-
-    output = torch.zeros_like(q)
+    max_seq_len = context_lengths.max().item() if max_seq_len is None else max_seq_len
+    sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale
+    output = torch.zeros_like(q) if output is None else output
 
     # NOTE For now, BLOCK_M and BLOCK_N are supposed to be equivalent with
     # the size of physical cache block (i.e. `block_size`)
diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py
index 0a42a2f13..6b3ed2999 100644
--- a/colossalai/kernel/triton/flash_decoding.py
+++ b/colossalai/kernel/triton/flash_decoding.py
@@ -195,6 +195,7 @@ def flash_decoding_attention(
     block_tables: torch.Tensor,
     block_size: int,
     max_seq_len_in_batch: int = None,
+    output: torch.Tensor = None,
     mid_output: torch.Tensor = None,
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
@@ -211,6 +212,7 @@ def flash_decoding_attention(
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
+        output (torch.Tensor): [bsz, 1, num_heads, head_dim]
         mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
         mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
@@ -292,7 +294,7 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )
 
-    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) if output is None else output
 
     grid = (triton.next_power_of_2(bsz), num_heads)
diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py
index bcc426e3a..772fe2200 100644
--- a/examples/inference/benchmark_llama.py
+++ b/examples/inference/benchmark_llama.py
@@ -91,7 +91,7 @@ def benchmark_inference(args):
     config.pad_token_id = config.eos_token_id
     model = transformers.LlamaForCausalLM(config).cuda()
     model = model.eval()
-    tokenizer = AutoTokenizer.from_pretrained("/home/caidi/llama_model/")
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
 
     if args.dtype == "fp16":
         model = model.half()
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index 294bba7da..bdd79836e 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -23,11 +23,12 @@ CUDA_VISIBLE_DEVICES_set_n_least_memory_usage() {
 CUDA_VISIBLE_DEVICES_set_n_least_memory_usage 1
 
 # benchmark llama2-7b one single GPU
+
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 256 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 512 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_256.txt
 done
 
 for bsz in 16 32 64; do
-    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 128 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
+    python3 ${PY_SCRIPT} -m llama2-7b --tp_size 1 --pp_size 1 -b $bsz -s 1024 --output_len 256 --mode ${mode} | tee logs/${mode}_${GPU}_${bsz}_1024.txt
 done
 
diff --git a/tests/test_infer/test_config_and_struct.py b/tests/test_infer/test_config_and_struct.py
index 348cd5d21..16f5bcc7f 100755
--- a/tests/test_infer/test_config_and_struct.py
+++ b/tests/test_infer/test_config_and_struct.py
@@ -17,6 +17,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
 
@@ -28,6 +29,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
 
@@ -39,6 +41,7 @@ def check_config_and_inference():
         sample_params=None,
         block_table=None,
         eos_token_id=2,
+        pad_token_id=2,
         max_output_len=256,
     )
     sequence.mark_running()
@@ -51,7 +54,12 @@ def check_config_and_inference():
     assert sequence.output_len == 0
     assert sequence.check_finish() == False
 
-    batch = BatchInfo(is_prompts=False)
+    batch = BatchInfo(
+        max_batch_size=8,
+        kv_max_split_num=16,
+        num_heads=2,
+        head_dim=128,
+    )
     batch.init_batch([sequence])
     batch.add_seqs([sequence2, sequence3])
     batch.add_seqs([sequence])
diff --git a/tests/test_infer/test_inference_engine.py b/tests/test_infer/test_inference_engine.py
index 4e5d8c733..19e1a5636 100644
--- a/tests/test_infer/test_inference_engine.py
+++ b/tests/test_infer/test_inference_engine.py
@@ -3,8 +3,7 @@ import random
 import numpy as np
 import pytest
 import torch
-import transformers
-from transformers import AutoTokenizer, GenerationConfig
+from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
 from colossalai.inference.config import InferenceConfig
@@ -22,8 +21,8 @@ def setup_seed(seed):
 def check_inference_engine(test_cai=False):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
-    model = transformers.LlamaForCausalLM(
-        transformers.LlamaConfig(
+    model = LlamaForCausalLM(
+        LlamaConfig(
             vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
         )
     ).cuda()
@@ -81,4 +80,4 @@ def test_inference_engine():
 
 
 if __name__ == "__main__":
-    test_inference_engine()
\ No newline at end of file
+    test_inference_engine()
diff --git a/tests/test_infer/test_request_handler.py b/tests/test_infer/test_request_handler.py
index 673fcf9cf..d589e9717 100644
--- a/tests/test_infer/test_request_handler.py
+++ b/tests/test_infer/test_request_handler.py
@@ -20,6 +20,7 @@ def check_running_list():
         input_token_id=[1, 2, 3],
         block_size=16,
         eos_token_id=0,
+        pad_token_id=0,
         sample_params=None,
         block_table=1,
     )
@@ -56,6 +57,7 @@ def check_request_handler():
         input_token_id=[1, 2, 3, 4, 5],
         block_size=16,
         eos_token_id=0,
+        pad_token_id=0,
         sample_params=None,
         block_table=torch.tensor([-1, -1]),
     )
diff --git a/tests/test_infer_ops/triton/test_decoding_attn.py b/tests/test_infer_ops/triton/test_decoding_attn.py
index 063ae2814..8d1a5a36c 100644
--- a/tests/test_infer_ops/triton/test_decoding_attn.py
+++ b/tests/test_infer_ops/triton/test_decoding_attn.py
@@ -91,6 +91,7 @@ def test_flash_decoding(
     max_seq_len_in_b = kv_seq_lengths.max().item()
     # The maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=q.dtype, device=q.device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -106,6 +107,7 @@ def test_flash_decoding(
         block_tables,
         block_size,
         max_seq_len_in_b,
+        output,
         mid_output,
         mid_output_lse,
         sm_scale=sm_scale,
@@ -184,6 +186,7 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
+    output = torch.empty((bsz, 1, num_attn_heads, HEAD_DIM), dtype=dtype, device=device)
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
@@ -199,6 +202,7 @@ def bench_kernel(
             block_tables,
             block_size,
             max_seq_len_in_b,
+            output,
             mid_output,
             mid_output_lse,
             sm_scale=sm_scale,