diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 4e429f7b8..aad0310cb 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -88,7 +88,7 @@ class InferenceConfig: use_cuda_kernel(bool): Whether to use cuda kernel, faster but lose some precision occasionally use_cuda_graph (bool): Whether to enforce CUDA graph execution. If False, we will disable CUDA graph and always execute the model in eager mode. If True, we will use eager execution in hybrid. max_context_len_to_capture (int): max context len that could be captured by CUDA Graph, per sequence - + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ # NOTE: arrange configs according to their importance and frequency of usage @@ -122,6 +122,7 @@ class InferenceConfig: pp_size: int = 1 micro_batch_size: int = 1 micro_batch_buffer_size: int = None + high_precision: Optional[bool] = False # cuda kernel option use_cuda_kernel: bool = False @@ -149,6 +150,10 @@ class InferenceConfig: self.dtype in _ALLOWED_DTYPES ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}" + # skip using casting when the data type is float32 + if self.dtype == torch.float32: + self.high_precision = False + # check distributed assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or ( self.tp_size * self.pp_size == dist.get_world_size() diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index e7bd1add7..a2388121b 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -61,6 +61,7 @@ class InferenceEngine: self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token self.generation_config = inference_config.to_generation_config(self.model_config) + self.high_precision = inference_config.high_precision model = model.eval() model = model.cuda() model.to(self.dtype) @@ -150,8 +151,10 @@ class InferenceEngine: batch_size=batch_size, is_prompts=False, use_cuda_graph=True, + high_precision=False, kv_seq_len=sequence_lengths[:batch_size].max().item(), head_dim=head_dim, + dtype=self.dtype, ) graph_runner = CUDAGraphRunner(self.model) @@ -391,8 +394,10 @@ class InferenceEngine: is_prompts=batch.is_prompts, use_cuda_kernel=self.inference_config.use_cuda_kernel, use_cuda_graph=use_cuda_graph, + high_precision=self.high_precision, kv_seq_len=sequence_lengths.max().item(), head_dim=batch.head_dim, + dtype=batch.dtype, ) return input_ids, output_tensor, input_meta_data @@ -421,7 +426,6 @@ class InferenceEngine: # TODO: padding_id is used for generating attn_mask and will be removed if nopad version is supported. logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) - if self.inference_config.pad_input: logits = logits[:, -1, :] self.request_handler.search_tokens(self.generation_config, logits) diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index ccb2e837d..37a714c83 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -2,6 +2,7 @@ from typing import List, Optional, Tuple import torch +import torch.nn.functional as F from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, @@ -30,10 +31,12 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) try: - HAS_TRITON = True + from flash_attn import flash_attn_varlen_func + + use_flash_attn2 = True except ImportError: - HAS_TRITON = False - logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") + use_flash_attn2 = False + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") def llama_causal_lm_forward( @@ -47,9 +50,10 @@ def llama_causal_lm_forward( """This function will replace the forward function of LlamaForCausalLM. Args: - batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. - k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. - v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + batch (BatchInfo): It stores the necessary input information for this inference. + k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache. + v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -61,6 +65,7 @@ def llama_causal_lm_forward( k_caches=k_caches, v_caches=v_caches, use_cuda_kernel=inputmetadata.use_cuda_kernel, # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unitest but triton kernel could + high_precision=inputmetadata.high_precision, ) logits = torch.mm(hidden_states, self.lm_head.weight) return logits @@ -74,13 +79,15 @@ def llama_model_forward( k_caches: List[torch.Tensor] = None, v_caches: List[torch.Tensor] = None, use_cuda_kernel: Optional[bool] = True, + high_precision: bool = False, ) -> torch.Tensor: """This function will replace the forward function of LlamaModel. Args: - batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None. - k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None. - v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None. + batch (BatchInfo): It stores the necessary input information for this inference. + k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache. + v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ block_tables = inputmetadata.block_tables sequence_lengths = inputmetadata.sequence_lengths @@ -94,6 +101,10 @@ def llama_model_forward( use_cuda_kernel = False hidden_states = self.embed_tokens(input_tokens_ids) + if use_cuda_kernel and inputmetadata != torch.float32 and use_flash_attn2: + cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + else: + cu_seqlens = None cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, inputmetadata.is_prompts) @@ -111,13 +122,15 @@ def llama_model_forward( v_cache=v_caches[layer_id], is_prompts=inputmetadata.is_prompts, sequence_lengths=sequence_lengths, - kv_seq_len=kv_seq_len, cos_sin=cos_sin, fd_inter_tensor=inputmetadata.fd_inter_tensor, + kv_seq_len=kv_seq_len, output_tensor=output_tensor, norm_output=norm_output, sm_scale=sm_scale, use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, ) if inputmetadata.is_prompts: @@ -134,38 +147,42 @@ def llama_decoder_layer_forward( self: LlamaDecoderLayer, hidden_states: torch.Tensor, residual: torch.Tensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, norm_output: torch.Tensor = None, sm_scale: int = None, use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """This function will replace the forward function of LlamaDecoderLayer. Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor]): Holding cos and sin. + fd_inter_tensor (FDIntermTensors): Holding tensors used for + storing intermediate values in flash-decoding. is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. - fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for - storing intermediate values in flash-decoding. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) @@ -175,14 +192,16 @@ def llama_decoder_layer_forward( block_tables=block_tables, k_cache=k_cache, v_cache=v_cache, - is_prompts=is_prompts, sequence_lengths=sequence_lengths, - kv_seq_len=kv_seq_len, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, + is_prompts=is_prompts, + kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, ) # Fully Connected @@ -276,43 +295,48 @@ class NopadLlamaAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, sm_scale: int = None, use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. - kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for - storing intermediate values in flash-decoding. Defaults to None. + storing intermediate values in flash-decoding. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ + token_nums = hidden_states.size(0) + if self.num_heads != self.num_key_value_heads: query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim) key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) else: # fused qkv - token_nums = hidden_states.size(0) hidden_states = hidden_states.expand(3, -1, -1) query_states, key_states, value_states = ( torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) @@ -321,23 +345,41 @@ class NopadLlamaAttention(LlamaAttention): block_size = k_cache.size(-2) if is_prompts: - if use_cuda_kernel: - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) else: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - ) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) else: if use_cuda_kernel: inference_ops.rotary_embedding_and_cache_copy( @@ -350,6 +392,7 @@ class NopadLlamaAttention(LlamaAttention): v_cache, sequence_lengths, block_tables, + high_precision, ) else: decoding_fused_rotary_embedding( @@ -435,6 +478,5 @@ class NopadLlamaMLP(LlamaMLP): """ hidden_states = hidden_states.expand(2, -1, -1) gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) - act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) - tmp_out = act_out * gate_up_proj_out[1] - return torch.mm(tmp_out, self.down_proj_weight) + act_out = inference_ops.silu_and_mul(gate_up_proj_out) + return torch.mm(act_out, self.down_proj_weight) diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index a6cbf2ee1..448a84c6f 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -136,7 +136,8 @@ def benchmark_inference(args): data = data_gen(mbsz, args.seq_len) - data = data.tolist() + if args.mode == "colossalai" or args.mode == "vllm": + data = data.tolist() generation_config = GenerationConfig( pad_token_id=tokenizer.pad_token_id, diff --git a/extensions/csrc/common/micros.h b/extensions/csrc/common/micros.h index c2241029f..5400a6dc1 100644 --- a/extensions/csrc/common/micros.h +++ b/extensions/csrc/common/micros.h @@ -56,6 +56,23 @@ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \ + TYPE, NAME, ...) \ + switch (HIGH_PRECISION) { \ + case false: { \ + const bool high_precision = false; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + break; \ + } \ + case true: { \ + const bool high_precision = true; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + break; \ + } \ + default: \ + AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \ + } + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ switch (TYPEIN) { \ case at::ScalarType::Float: { \ diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h index 2a767620a..77de7c12a 100644 --- a/extensions/csrc/common/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -27,5 +27,18 @@ struct MPTypeTrait { using Type = float; }; +template +struct ScalarTypeTrait; + +template +struct ScalarTypeTrait { + using Type = typename MPTypeTrait::Type; +}; + +template +struct ScalarTypeTrait { + using Type = T; +}; + } // namespace common } // namespace colossalAI diff --git a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000..3f6adc018 --- /dev/null +++ b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu @@ -0,0 +1,195 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" + +template +__global__ void context_kv_cache_memcpy_kernel( + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cu_seqlens, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int batch_size, + const int block_table_stride, + const int64_t key_stride, + const int64_t value_stride +) +{ + const int seq_token_id = blockIdx.x; + const int seq_id = blockIdx.y; + const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size]; + + if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { + return ; + } + + const int block_offset = seq_token_id % block_size; + const int hidden_size = head_num * head_dim; + const int total_token_id = cu_seqlens[seq_id] + seq_token_id; + int head_id; + int head_offset; + int64_t key_src_id; + int64_t value_src_id; + int64_t target_id; + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + head_id = i / head_dim; + head_offset = i % head_dim; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); + } + + // tail process + for (; i < hidden_size; ++i ) { + head_id = i / head_dim; + head_offset = i % head_dim; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + +} + +template +void apply_context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(2); + int batch_size = block_tables.size(0); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + if (head_dim % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + int thread_nums = head_num * head_dim / vec_size; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(max_seq_len_in_batch, batch_size); + dim3 block(std::min(thread_nums, 512)); + + switch (vec_size) { + case 1: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + case 2: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + case 4: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "context_kv_cache_memcpy", + apply_context_kv_cache_memcpy( + key, + value, + key_cache, + value_cache, + sequence_lengths, + cu_seqlens, + block_tables, + max_seq_len_in_batch + );) +} diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 3b1197a91..08889b236 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel( return ; } - for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) { + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { const int head_id = i / head_dim; const int head_offset = i % head_dim; const int64_t key_src_id = seq_id * key_stride + i; @@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel( copy_vector(value_cache + target_id, value + value_src_id); } + for (; i < hidden_size; ++i ) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + } template diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu index 697dc7110..8feb6b343 100644 --- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -1,14 +1,15 @@ - +// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision #include #include #include "utils/vector_copy_utils.h" #include "../common/micros.h" +#include "../common/mp_type_traits.h" -template +template __device__ void apply_emb_rotary_compute( - scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, const int64_t stride, + scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr, + const m_scalar_t* __restrict__ sin_ptr, const int64_t stride, const int token_id, const int shard_block_size, const int half_head_dim, const int head_num, const int head_dim) { scalar_t x[VecSize]; @@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute( #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - - y[j] * sin_ptr[j * 32 + shard_offset]; - out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + - x[j] * sin_ptr[j * 32 + shard_offset]; + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); } copy_vector(src + addr_offset, out_x); @@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy( } } -template +template __device__ void cos_sin_memory_access( const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, - scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id, + m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id, const int shard_block_size, const int cos_stride, const int sin_stride, const int half_head_dim) { for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { @@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access( const int shard_offset = (i % shard_block_size) / VecSize; const int shard_head = (i / shard_block_size) * shard_block_size + i % VecSize * 32; - cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i]; - sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i]; + cos_ptr[shard_head + shard_offset] = static_cast(cos[token_id * cos_stride + i]); + sin_ptr[shard_head + shard_offset] = static_cast(sin[token_id * sin_stride + i]); } } -template +template __device__ void apply_k_rotary_emb_compute( scalar_t* __restrict__ key, scalar_t* __restrict__ value, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, + const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr, const int* __restrict__ sequence_lengths, const int* __restrict__ block_tables, const int64_t key_stride, const int64_t value_stride, const int token_id, @@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute( #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - - y[j] * sin_ptr[j * 32 + shard_offset]; - out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + - x[j] * sin_ptr[j * 32 + shard_offset]; + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); } copy_vector(key_cache + target_id, out_x); @@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute( block_size, block_offset, head_dim, half_head_dim); } -template +template __global__ void rotary_embedding_and_cache_copy_kernel( scalar_t* __restrict__ query, scalar_t* __restrict__ key, @@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel( extern __shared__ char shard_ptr[]; - scalar_t *cos_ptr = (scalar_t*)shard_ptr; - scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; // apply cos_sin memcopy - cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); __syncthreads(); //compute query - apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); //compute key and copy kv - apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); } -template +template __global__ void rotary_embedding_kernel( scalar_t* __restrict__ query, scalar_t* __restrict__ key, @@ -202,21 +203,21 @@ __global__ void rotary_embedding_kernel( extern __shared__ char shard_ptr[]; - scalar_t *cos_ptr = (scalar_t*)shard_ptr; - scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; // apply cos_sin memcopy - cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); __syncthreads(); //compute query - apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); //compute key - apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); } -template +template void apply_rotary_embedding_and_cache_copy( at::Tensor& query, // [num_tokens, head_num, head_dim] at::Tensor& key, // [num_tokens, kv_head_num, head_dim] @@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy( int sin_stride = sin.stride(0); int block_table_stride = block_tables.stride(0); + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + int vec_size = get_vec_size(query); if ((head_dim / 2) % vec_size != 0) { @@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy( switch (vec_size) { case 1: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -283,7 +286,7 @@ void apply_rotary_embedding_and_cache_copy( ); break; case 2: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -307,7 +310,7 @@ void apply_rotary_embedding_and_cache_copy( ); break; case 4: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -338,12 +341,12 @@ void apply_rotary_embedding_and_cache_copy( AT_CUDA_CHECK(cudaGetLastError()); } -template +template void apply_rotary_embedding( at::Tensor& query, // [total_tokens, head_num, head_dim] at::Tensor& key, // [total_tokens, kv_head_num, head_dim] at::Tensor& cos, // [total_tokens, head_dim] - at::Tensor& sin // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] ){ int num_tokens = query.size(0); int head_num = query.size(1); @@ -355,6 +358,8 @@ void apply_rotary_embedding( int cos_stride = cos.stride(0); int sin_stride = sin.stride(0); + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + int vec_size = get_vec_size(query); if ((head_dim / 2) % vec_size != 0) { @@ -373,7 +378,7 @@ void apply_rotary_embedding( switch (vec_size) { case 1: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -389,7 +394,7 @@ void apply_rotary_embedding( ); break; case 2: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -405,7 +410,7 @@ void apply_rotary_embedding( ); break; case 4: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -436,12 +441,14 @@ void rotary_embedding_and_cache_copy( at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] - at::Tensor& block_tables) // [batch_size, max_seq_len] + at::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, query.scalar_type(), "rotary_embedding_and_cache_copy", - apply_rotary_embedding_and_cache_copy( + apply_rotary_embedding_and_cache_copy( query, key, value, @@ -458,12 +465,14 @@ void rotary_embedding( at::Tensor& query, // [total_tokens, head_num, head_dim] at::Tensor& key, // [total_tokens, kv_head_num, head_dim] at::Tensor& cos, // [total_tokens, head_dim] - at::Tensor& sin // [total_tokens, head_dim] + at::Tensor& sin, // [total_tokens, head_dim] + bool high_precision ){ - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, query.scalar_type(), "rotary_embedding", - apply_rotary_embedding( + apply_rotary_embedding( query, key, cos, diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 4282f5382..541146e3a 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ b/extensions/csrc/cuda/pybind/inference.cpp @@ -9,11 +9,22 @@ void decode_kv_cache_memcpy( torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch); + void rotary_embedding( torch::Tensor& query, // [total_tokens, head_num, head_dim] torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] torch::Tensor& cos, // [total_tokens, head_dim] - torch::Tensor& sin); // [total_tokens, head_dim] + torch::Tensor& sin, // [total_tokens, head_dim] + bool high_precision); void rotary_embedding_and_cache_copy( torch::Tensor& query, // [num_tokens, head_num, head_dim] @@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy( torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_dim] torch::Tensor& sequence_lengths, // [batch_size] - torch::Tensor& block_tables); // [batch_size, max_seq_len] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision); + torch::Tensor silu_and_mul(const torch::Tensor& ins); void rms_layernorm(torch::Tensor& out, // [..., hidden_size] @@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the context stage."); + m.def( "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, "performing Rotary Embedding-related calculations and KVCache Memcopy."); diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h index 524ef46c6..bd2465bea 100644 --- a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h @@ -11,6 +11,8 @@ #include #include +#include "utils/vector_copy_utils.h" + namespace { int log2_ceil(int value) { diff --git a/extensions/csrc/cuda/utils/vector_copy_utils.h b/extensions/csrc/cuda/utils/vector_copy_utils.h index 3c3afa0b3..5157ec738 100644 --- a/extensions/csrc/cuda/utils/vector_copy_utils.h +++ b/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -11,16 +11,16 @@ template __device__ __inline__ void copy_vector(T *dst, const T *src) { using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; // Note(LiuYang): Here static_cast can't be used for cast between two pointer - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); } template <> __device__ __inline__ void copy_vector(float *dst, const float *src) { // Since the maximum memory alignment length is 128 bits, we choose float4 // here. - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); *(reinterpret_cast(dst + 4)) = - *(reinterpret_cast(src + 4)); + *(reinterpret_cast(src + 4)); } template diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index ae3754ca7..4e0afc819 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension): for fname in [ "cuda/pybind/inference.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/context_kv_cache_memcpy_kernel.cu", "cuda/fused_rotary_emb_and_cache_kernel.cu", "cuda/activation_kernel.cu", "cuda/rms_layernorm_kernel.cu", diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py index d5259a596..3fa17037f 100644 --- a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py +++ b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py @@ -1,8 +1,10 @@ import pytest import torch +import torch.nn.functional as F from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2 from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data inference_ops = InferenceOpsLoader().load() @@ -10,12 +12,7 @@ inference_ops = InferenceOpsLoader().load() HEAD_DIM = 4 -@pytest.mark.parametrize("bsz", [4, 7, 32]) -@pytest.mark.parametrize("block_size", [16, 32, 64]) -@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) -@pytest.mark.parametrize("num_kv_heads", [16]) -@pytest.mark.parametrize("same_context_len", [True, False]) -def test_copy_kv_to_caches( +def run_decode_copy_kv_to_caches( bsz: int, block_size: int, max_num_blocks_per_seq: int, @@ -61,5 +58,65 @@ def test_copy_kv_to_caches( assert torch.equal(v_target, v_source) +def run_context_copy_kv_to_cache( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + torch.manual_seed(123) + + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." + max_seq_len = max_num_blocks_per_seq * block_size + dtype = torch.float16 + device = get_current_device() + + if same_context_len: + context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device) + else: + context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device) + + num_tokens = torch.sum(context_lengths).item() + + max_seq_len_in_batch = context_lengths.max() + cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + + kv_size = (num_tokens, num_kv_heads, HEAD_DIM) + key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) + + k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2( + key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device + ) + + block_tables = block_tables.to(device=device) + k_cache = torch.zeros_like(k_cache_ref) + v_cache = torch.zeros_like(v_cache_ref) + + inference_ops.context_kv_cache_memcpy( + key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch + ) + + assert torch.equal(k_cache, k_cache_ref) + assert torch.equal(v_cache, v_cache_ref) + + +@pytest.mark.parametrize("bsz", [4, 7, 32]) +@pytest.mark.parametrize("block_size", [16, 32, 64]) +@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) +@pytest.mark.parametrize("num_kv_heads", [16]) +@pytest.mark.parametrize("same_context_len", [True, False]) +def test_kv_cache_memcopy( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len) + run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len) + + if __name__ == "__main__": - test_copy_kv_to_caches(4, 32, 8, 16, True) + test_kv_cache_memcopy(4, 32, 8, 16, True) diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py index b9c0a3269..9e0a8b0db 100644 --- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py +++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py @@ -1,3 +1,4 @@ +import numpy as np import pytest import torch from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb @@ -10,11 +11,18 @@ from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb +def numpy_allclose(x, y, rtol, atol): + x_numpy = x.detach().cpu().numpy() + y_numpy = y.detach().cpu().numpy() + + np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("BATCH_SIZE", [4]) @pytest.mark.parametrize("SEQ_LEN", [64]) @pytest.mark.parametrize("H", [32]) @pytest.mark.parametrize("D", [64]) -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): torch.manual_seed(10) TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN @@ -54,17 +62,36 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): kv_seq_lengths = past_kv_seq_lengths + 1 block_tables = block_tables.to(device="cuda") - q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) - k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) new_q_copy = new_q.clone() new_k_copy = new_k.clone() + if dtype == torch.float16: + rtol = 1e-3 + atol = 1e-3 + + new_q_fp16 = new_q.clone() + new_k_fp16 = new_k.clone() + + high_precision_cos = cos[:BATCH_SIZE].to(torch.float32) + high_precision_sin = sin[:BATCH_SIZE].to(torch.float32) + high_precision_q = new_q.to(torch.float32) + high_precision_k = new_k.to(torch.float32) + q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16) + k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16) + + else: + rtol = 1e-5 + atol = 1e-7 + + q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE]) + inference_ops.rotary_embedding_and_cache_copy( - new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables + new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True ) - inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin) + inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True) past_kv_seq_len = kv_seq_lengths - 1 target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size] @@ -74,18 +101,26 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype): v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze() v_source = new_v.squeeze() - assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6) - assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6) + numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol) + numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol) - assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6) - assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6) + numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol) + numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol) assert k_target.shape == k_source.shape - assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6) + numpy_allclose(k_target, k_source, rtol=rtol, atol=atol) assert v_target.shape == v_source.shape assert torch.equal(v_target, v_source) + if dtype == torch.float16: + # After testing cuda fp16 high_precision, it was found to have higher precision than torch fp16. Therefore, the threshold here has been relaxed to pass the test. + rtol = 1e-3 + atol = 1e-1 + inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False) + numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol) + numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol) + if __name__ == "__main__": - test_rotary_emb(16, 512, 4, 128, torch.float16) + test_rotary_emb(16, 64, 4, 128, torch.float16)