diff --git a/colossalai/inference/config.py b/colossalai/inference/config.py index 7ce4719e7..7b49e8f77 100644 --- a/colossalai/inference/config.py +++ b/colossalai/inference/config.py @@ -55,7 +55,7 @@ class InferenceConfig: pp_size (int): Pipeline parallel size, defaults to 1. micro_batch_size (int): the micro batch size, defaults to 1. Only useful when `pp_size` > 1. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. - + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ # NOTE: arrange configs according to their importance and frequency of usage @@ -89,6 +89,7 @@ class InferenceConfig: pp_size: int = 1 micro_batch_size: int = 1 micro_batch_buffer_size: int = None + high_precision: Optional[bool] = False def __post_init__(self): self._verify_config() @@ -108,6 +109,10 @@ class InferenceConfig: self.dtype in _ALLOWED_DTYPES ), f"Expected dtype to be in {_ALLOWED_DTYPES} but found an unknown dtype: {self.dtype}" + # skip using casting when the data type is float32 + if self.dtype == torch.float32: + self.high_precision = False + # check distributed assert (not torch.distributed.is_initialized() and self.tp_size * self.pp_size == 1) or ( self.tp_size * self.pp_size == dist.get_world_size() diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py index 8c7829c02..4833e5b0c 100644 --- a/colossalai/inference/core/engine.py +++ b/colossalai/inference/core/engine.py @@ -56,6 +56,7 @@ class InferenceEngine: self.tokenizer = tokenizer self.tokenizer.pad_token = self.tokenizer.eos_token self.generation_config = inference_config.to_generation_config(self.model_config) + self.high_precision = inference_config.high_precision model = model.eval() model = model.cuda() model.to(self.dtype) @@ -297,6 +298,7 @@ class InferenceEngine: batch, self.k_cahce, self.v_cache, + self.high_precision, ) if self.inference_config.pad_input: diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py index 12de4802b..9ea79551e 100644 --- a/colossalai/inference/modeling/models/nopadding_llama.py +++ b/colossalai/inference/modeling/models/nopadding_llama.py @@ -2,6 +2,7 @@ from typing import List, Optional, Tuple import torch +import torch.nn.functional as F from transformers.models.llama.modeling_llama import ( LlamaAttention, LlamaConfig, @@ -30,24 +31,28 @@ inference_ops = InferenceOpsLoader().load() logger = get_dist_logger(__name__) try: - HAS_TRITON = True + from flash_attn import flash_attn_varlen_func + + use_flash_attn2 = True except ImportError: - HAS_TRITON = False - logger.warning(f"triton has not been installed yet, we will use torch to complete the attention calculation.") + use_flash_attn2 = False + logger.warning(f"flash_attn2 has not been installed yet, we will use triton flash attn instead.") def llama_causal_lm_forward( self: LlamaForCausalLM, - batch: BatchBucket = None, - k_caches: List[torch.Tensor] = None, - v_caches: List[torch.Tensor] = None, + batch: BatchBucket, + k_caches: List[torch.Tensor], + v_caches: List[torch.Tensor], + high_precision: bool = False, ): """This function will replace the forward function of LlamaForCausalLM. Args: - batch (BatchInfo, optional): It stores the necessary input information for this inference. Defaults to None. 
-        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
-        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+        batch (BatchInfo): It stores the necessary input information for this inference.
+        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
+        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
     """
 
     # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
@@ -56,6 +61,7 @@ def llama_causal_lm_forward(
         batch=batch,
         k_caches=k_caches,
         v_caches=v_caches,
+        high_precision=high_precision,
     )
     logits = torch.mm(hidden_states, self.lm_head.weight)
     return logits
@@ -63,16 +69,18 @@ def llama_model_forward(
 def llama_model_forward(
     self: LlamaModel,
-    batch: BatchBucket = None,
-    k_caches: List[torch.Tensor] = None,
-    v_caches: List[torch.Tensor] = None,
+    batch: BatchBucket,
+    k_caches: List[torch.Tensor],
+    v_caches: List[torch.Tensor],
+    high_precision: bool = False,
 ):
     """This function will replace the forward function of LlamaModel.
 
     Args:
-        batch (BatchInfo, optional): It stores the necessary input information for this inference.. Defaults to None.
-        k_caches (List[torch.Tensor], optional): It holds the GPU memory for the key cache. Defaults to None.
-        v_caches (List[torch.Tensor], optional): It holds the GPU memory for the value cache. Defaults to None.
+        batch (BatchInfo): It stores the necessary input information for this inference.
+        k_caches (List[torch.Tensor]): It holds the GPU memory for the key cache.
+        v_caches (List[torch.Tensor]): It holds the GPU memory for the value cache.
+        high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
""" input_ids = batch.get_1D_inputs() block_tables = batch.get_block_table_tensor() @@ -86,6 +94,11 @@ def llama_model_forward( if batch_size >= 32 and kv_seq_len > 512: use_cuda_kernel = False + if use_cuda_kernel and batch.dtype != torch.float32 and use_flash_attn2: + cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0)) + else: + cu_seqlens = None + hidden_states = self.embed_tokens(input_ids) cos_sin = get_xine_cache(sequence_lengths, self._cos_cached, self._sin_cached, batch.is_prompts) @@ -110,15 +123,17 @@ def llama_model_forward( block_tables=block_tables, k_cache=k_caches[layer_id], v_cache=v_caches[layer_id], - is_prompts=batch.is_prompts, sequence_lengths=sequence_lengths, - kv_seq_len=kv_seq_len, cos_sin=cos_sin, fd_inter_tensor=batch.fd_inter_tensor, + is_prompts=batch.is_prompts, + kv_seq_len=kv_seq_len, output_tensor=output_tensor, norm_output=norm_output, sm_scale=sm_scale, use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, ) if batch.is_prompts: @@ -135,38 +150,42 @@ def llama_decoder_layer_forward( self: LlamaDecoderLayer, hidden_states: torch.Tensor, residual: torch.Tensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, norm_output: torch.Tensor = None, sm_scale: int = None, use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """This function will replace the forward function of LlamaDecoderLayer. Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. residual (torch.Tensor): shape [token_num, embed_dim], used to be added to hidden_states in out_proj. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor]): Holding cos and sin. + fd_inter_tensor (FDIntermTensors): Holding tensors used for + storing intermediate values in flash-decoding. is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. 
- fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for - storing intermediate values in flash-decoding. Defaults to None. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. norm_output (torch.Tensor, optional): The mid tensor holds the output of layernorm. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ hidden_states, residual = self.input_layernorm(hidden_states, norm_output, residual, use_cuda_kernel) @@ -176,14 +195,16 @@ def llama_decoder_layer_forward( block_tables=block_tables, k_cache=k_cache, v_cache=v_cache, - is_prompts=is_prompts, sequence_lengths=sequence_lengths, - kv_seq_len=kv_seq_len, cos_sin=cos_sin, fd_inter_tensor=fd_inter_tensor, + is_prompts=is_prompts, + kv_seq_len=kv_seq_len, output_tensor=output_tensor, sm_scale=sm_scale, use_cuda_kernel=use_cuda_kernel, + cu_seqlens=cu_seqlens, + high_precision=high_precision, ) # Fully Connected @@ -277,43 +298,48 @@ class NopadLlamaAttention(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - block_tables: torch.Tensor = None, - k_cache: torch.Tensor = None, - v_cache: torch.Tensor = None, + block_tables: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + sequence_lengths: torch.Tensor, + cos_sin: Tuple[torch.Tensor], + fd_inter_tensor: FDIntermTensors, is_prompts: bool = True, - sequence_lengths: torch.Tensor = None, kv_seq_len: int = 0, - cos_sin: Tuple[torch.Tensor] = None, - fd_inter_tensor: FDIntermTensors = None, output_tensor: torch.Tensor = None, sm_scale: int = None, use_cuda_kernel: bool = True, + cu_seqlens: torch.Tensor = None, + high_precision: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """ Args: hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim]. - block_tables (torch.Tensor, optional): A 2D tensor of shape [batch_size, max_blocks_per_sequence], - storing mapping of token_position_id -> block_id. Defaults to None. - k_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - v_cache (torch.Tensor, optional): It holds the GPU memory for the key cache. Defaults to None. - is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. - sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. Defaults to None. - kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. - cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. Defaults to None. + block_tables (torch.Tensor): A 2D tensor of shape [batch_size, max_blocks_per_sequence], + storing mapping of token_position_id -> block_id. + k_cache (torch.Tensor): It holds the GPU memory for the key cache. + v_cache (torch.Tensor): It holds the GPU memory for the key cache. + sequence_lengths (torch.Tensor, optional): Holding the sequence length of each sequence. + cos_sin (Tuple[torch.Tensor], optional): Holding cos and sin. fd_inter_tensor (FDIntermTensors, optional): Holding tensors used for - storing intermediate values in flash-decoding. Defaults to None. 
+ storing intermediate values in flash-decoding. + is_prompts (bool, optional): Whether the current inference process is in the context input phase. Defaults to True. + kv_seq_len (int, optional): The max sequence length of input sequences. Defaults to 0. output_tensor (torch.Tensor, optional): The mid tensor holds the output of attention. Defaults to None. sm_scale (int, optional): Used for flash attention. Defaults to None. use_cuda_kernel: (bool, optional): Whether to use cuda kernel. Defaults to True. + cu_seqlens(torch.Tensor, optional): Holding the cumulative sum of sequence length. + high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False. """ + token_nums = hidden_states.size(0) + if self.num_heads != self.num_key_value_heads: query_states = torch.mm(hidden_states, self.q_proj_weight).view(-1, self.num_heads, self.head_dim) key_states = torch.mm(hidden_states, self.k_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) value_states = torch.mm(hidden_states, self.v_proj_weight).view(-1, self.num_key_value_heads, self.head_dim) else: # fused qkv - token_nums = hidden_states.size(0) hidden_states = hidden_states.expand(3, -1, -1) query_states, key_states, value_states = ( torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) @@ -322,23 +348,41 @@ class NopadLlamaAttention(LlamaAttention): block_size = k_cache.size(-2) if is_prompts: - if use_cuda_kernel: - inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + if use_cuda_kernel and query_states.dtype != torch.float32 and use_flash_attn2: + # flash attn 2 currently only supports FP16/BF16. + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) else: rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) - attn_output = context_attention_unpadded( - q=query_states, - k=key_states, - v=value_states, - k_cache=k_cache, - v_cache=v_cache, - context_lengths=sequence_lengths, - block_tables=block_tables, - block_size=block_size, - output=output_tensor, - max_seq_len=kv_seq_len, - sm_scale=sm_scale, - ) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) else: if use_cuda_kernel: inference_ops.rotary_embedding_and_cache_copy( @@ -351,6 +395,7 @@ class NopadLlamaAttention(LlamaAttention): v_cache, sequence_lengths, block_tables, + high_precision, ) else: decoding_fused_rotary_embedding( @@ -436,6 +481,5 @@ class NopadLlamaMLP(LlamaMLP): """ hidden_states = hidden_states.expand(2, -1, -1) gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight) - act_out = torch.nn.functional.silu(gate_up_proj_out[0], inplace=True) - tmp_out = act_out * gate_up_proj_out[1] - return torch.mm(tmp_out, self.down_proj_weight) + act_out = 
inference_ops.silu_and_mul(gate_up_proj_out) + return torch.mm(act_out, self.down_proj_weight) diff --git a/examples/inference/benchmark_llama.py b/examples/inference/benchmark_llama.py index a6cbf2ee1..448a84c6f 100644 --- a/examples/inference/benchmark_llama.py +++ b/examples/inference/benchmark_llama.py @@ -136,7 +136,8 @@ def benchmark_inference(args): data = data_gen(mbsz, args.seq_len) - data = data.tolist() + if args.mode == "colossalai" or args.mode == "vllm": + data = data.tolist() generation_config = GenerationConfig( pad_token_id=tokenizer.pad_token_id, diff --git a/extensions/csrc/common/micros.h b/extensions/csrc/common/micros.h index c2241029f..5400a6dc1 100644 --- a/extensions/csrc/common/micros.h +++ b/extensions/csrc/common/micros.h @@ -56,6 +56,23 @@ AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \ } +#define DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION(HIGH_PRECISION, \ + TYPE, NAME, ...) \ + switch (HIGH_PRECISION) { \ + case false: { \ + const bool high_precision = false; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + break; \ + } \ + case true: { \ + const bool high_precision = true; \ + DISPATCH_FLOAT_HALF_AND_BFLOAT(TYPE, NAME, __VA_ARGS__); \ + break; \ + } \ + default: \ + AT_ERROR("HIGH_PRECISION must be bool, but get ", HIGH_PRECISION, "."); \ + } + #define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \ switch (TYPEIN) { \ case at::ScalarType::Float: { \ diff --git a/extensions/csrc/common/mp_type_traits.h b/extensions/csrc/common/mp_type_traits.h index 2a767620a..77de7c12a 100644 --- a/extensions/csrc/common/mp_type_traits.h +++ b/extensions/csrc/common/mp_type_traits.h @@ -27,5 +27,18 @@ struct MPTypeTrait { using Type = float; }; +template +struct ScalarTypeTrait; + +template +struct ScalarTypeTrait { + using Type = typename MPTypeTrait::Type; +}; + +template +struct ScalarTypeTrait { + using Type = T; +}; + } // namespace common } // namespace colossalAI diff --git a/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu new file mode 100644 index 000000000..3f6adc018 --- /dev/null +++ b/extensions/csrc/cuda/context_kv_cache_memcpy_kernel.cu @@ -0,0 +1,195 @@ +#include +#include + +#include "utils/vector_copy_utils.h" +#include "../common/micros.h" + +template +__global__ void context_kv_cache_memcpy_kernel( + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + scalar_t* __restrict__ key_cache, + scalar_t* __restrict__ value_cache, + const int* __restrict__ sequence_lengths, + const int* __restrict__ cu_seqlens, + const int* __restrict__ block_tables, + const int head_num, + const int head_dim, + const int block_size, + const int batch_size, + const int block_table_stride, + const int64_t key_stride, + const int64_t value_stride +) +{ + const int seq_token_id = blockIdx.x; + const int seq_id = blockIdx.y; + const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size]; + + if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) { + return ; + } + + const int block_offset = seq_token_id % block_size; + const int hidden_size = head_num * head_dim; + const int total_token_id = cu_seqlens[seq_id] + seq_token_id; + int head_id; + int head_offset; + int64_t key_src_id; + int64_t value_src_id; + int64_t target_id; + + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { + head_id = i / head_dim; + head_offset = i % head_dim; + 
key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + copy_vector(key_cache + target_id, key + key_src_id); + copy_vector(value_cache + target_id, value + value_src_id); + } + + // tail process + for (; i < hidden_size; ++i ) { + head_id = i / head_dim; + head_offset = i % head_dim; + key_src_id = total_token_id * key_stride + i; + value_src_id = total_token_id * value_stride + i; + target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + +} + +template +void apply_context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + int num_tokens = key.size(0); + int head_num = key.size(1); + int head_dim = key.size(2); + int block_size = key_cache.size(2); + int batch_size = block_tables.size(0); + + int64_t key_stride = key.stride(0); + int64_t value_stride = value.stride(0); + int block_table_stride = block_tables.stride(0); + + int vec_size = get_vec_size(key); + + if (head_dim % vec_size != 0) { + // Disable vectorized loading optimization when head_dim is not divisible by VecSize. + vec_size = 1; + } + + int thread_nums = head_num * head_dim / vec_size; + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(max_seq_len_in_batch, batch_size); + dim3 block(std::min(thread_nums, 512)); + + switch (vec_size) { + case 1: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + case 2: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + case 4: + context_kv_cache_memcpy_kernel<<>>( + key.data_ptr(), + value.data_ptr(), + key_cache.data_ptr(), + value_cache.data_ptr(), + sequence_lengths.data_ptr(), + cu_seqlens.data_ptr(), + block_tables.data_ptr(), + head_num, + head_dim, + block_size, + batch_size, + block_table_stride, + key_stride, + value_stride + ); + break; + default: + AT_ERROR("Unsupported vectorized size ", vec_size); + break; + } + + AT_CUDA_CHECK(cudaGetLastError()); + +} + +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // 
[batch_size, max_seq_len] + int max_seq_len_in_batch) +{ + DISPATCH_FLOAT_HALF_AND_BFLOAT( + key.scalar_type(), + "context_kv_cache_memcpy", + apply_context_kv_cache_memcpy( + key, + value, + key_cache, + value_cache, + sequence_lengths, + cu_seqlens, + block_tables, + max_seq_len_in_batch + );) +} diff --git a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu index 3b1197a91..08889b236 100644 --- a/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu +++ b/extensions/csrc/cuda/decode_kv_cache_memcpy_kernel.cu @@ -30,7 +30,9 @@ __global__ void decode_kv_cache_memcpy_kernel( return ; } - for (int i = threadIdx.x * VecSize; i < hidden_size; i += blockDim.x * VecSize) { + int i = threadIdx.x * VecSize; + + for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) { const int head_id = i / head_dim; const int head_offset = i % head_dim; const int64_t key_src_id = seq_id * key_stride + i; @@ -43,6 +45,19 @@ __global__ void decode_kv_cache_memcpy_kernel( copy_vector(value_cache + target_id, value + value_src_id); } + for (; i < hidden_size; ++i ) { + const int head_id = i / head_dim; + const int head_offset = i % head_dim; + const int64_t key_src_id = seq_id * key_stride + i; + const int64_t value_src_id = seq_id * value_stride + i; + const int64_t target_id = block_id * hidden_size * block_size + + head_id * block_size * head_dim + + block_offset * head_dim + head_offset; + + key_cache[target_id] = key[key_src_id]; + value_cache[target_id] = value[value_src_id]; + } + } template diff --git a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu index 697dc7110..8feb6b343 100644 --- a/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu +++ b/extensions/csrc/cuda/fused_rotary_emb_and_cache_kernel.cu @@ -1,14 +1,15 @@ - +// in transformers source code, huggingface uses fp16 to compute rope so we follow the same precision #include #include #include "utils/vector_copy_utils.h" #include "../common/micros.h" +#include "../common/mp_type_traits.h" -template +template __device__ void apply_emb_rotary_compute( - scalar_t* __restrict__ src, const scalar_t* __restrict__ cos_ptr, - const scalar_t* __restrict__ sin_ptr, const int64_t stride, + scalar_t* __restrict__ src, const m_scalar_t* __restrict__ cos_ptr, + const m_scalar_t* __restrict__ sin_ptr, const int64_t stride, const int token_id, const int shard_block_size, const int half_head_dim, const int head_num, const int head_dim) { scalar_t x[VecSize]; @@ -30,10 +31,10 @@ __device__ void apply_emb_rotary_compute( #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - - y[j] * sin_ptr[j * 32 + shard_offset]; - out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + - x[j] * sin_ptr[j * 32 + shard_offset]; + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); } copy_vector(src + addr_offset, out_x); @@ -62,10 +63,10 @@ __device__ void apply_kv_memcopy( } } -template +template __device__ void cos_sin_memory_access( const scalar_t* __restrict__ cos, const scalar_t* __restrict__ sin, - scalar_t* cos_ptr, scalar_t* sin_ptr, const int token_id, + m_scalar_t* cos_ptr, m_scalar_t* sin_ptr, const int token_id, const int shard_block_size, const int cos_stride, const int 
sin_stride, const int half_head_dim) { for (int i = threadIdx.x; i < half_head_dim; i += blockDim.x) { @@ -73,16 +74,16 @@ __device__ void cos_sin_memory_access( const int shard_offset = (i % shard_block_size) / VecSize; const int shard_head = (i / shard_block_size) * shard_block_size + i % VecSize * 32; - cos_ptr[shard_head + shard_offset] = cos[token_id * cos_stride + i]; - sin_ptr[shard_head + shard_offset] = sin[token_id * sin_stride + i]; + cos_ptr[shard_head + shard_offset] = static_cast(cos[token_id * cos_stride + i]); + sin_ptr[shard_head + shard_offset] = static_cast(sin[token_id * sin_stride + i]); } } -template +template __device__ void apply_k_rotary_emb_compute( scalar_t* __restrict__ key, scalar_t* __restrict__ value, scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache, - const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, + const m_scalar_t* __restrict__ cos_ptr, const m_scalar_t* __restrict__ sin_ptr, const int* __restrict__ sequence_lengths, const int* __restrict__ block_tables, const int64_t key_stride, const int64_t value_stride, const int token_id, @@ -120,10 +121,10 @@ __device__ void apply_k_rotary_emb_compute( #pragma unroll for (int j = 0; j < VecSize; j++) { - out_x[j] = x[j] * cos_ptr[j * 32 + shard_offset] - - y[j] * sin_ptr[j * 32 + shard_offset]; - out_y[j] = y[j] * cos_ptr[j * 32 + shard_offset] + - x[j] * sin_ptr[j * 32 + shard_offset]; + out_x[j] = static_cast(static_cast(x[j]) * cos_ptr[j * 32 + shard_offset] - + static_cast(y[j]) * sin_ptr[j * 32 + shard_offset]); + out_y[j] = static_cast(static_cast(y[j]) * cos_ptr[j * 32 + shard_offset] + + static_cast(x[j]) * sin_ptr[j * 32 + shard_offset]); } copy_vector(key_cache + target_id, out_x); @@ -137,7 +138,7 @@ __device__ void apply_k_rotary_emb_compute( block_size, block_offset, head_dim, half_head_dim); } -template +template __global__ void rotary_embedding_and_cache_copy_kernel( scalar_t* __restrict__ query, scalar_t* __restrict__ key, @@ -167,21 +168,21 @@ __global__ void rotary_embedding_and_cache_copy_kernel( extern __shared__ char shard_ptr[]; - scalar_t *cos_ptr = (scalar_t*)shard_ptr; - scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; // apply cos_sin memcopy - cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); __syncthreads(); //compute query - apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); //compute key and copy kv - apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); + apply_k_rotary_emb_compute(key, value, key_cache, value_cache, cos_ptr, sin_ptr, sequence_lengths, block_tables, key_stride, value_stride, token_id, block_table_stride, head_num, head_dim, kv_head_num, block_size, half_head_dim, shard_block_size); } -template +template __global__ void rotary_embedding_kernel( scalar_t* __restrict__ query, scalar_t* __restrict__ key, @@ -202,21 +203,21 @@ __global__ 
void rotary_embedding_kernel( extern __shared__ char shard_ptr[]; - scalar_t *cos_ptr = (scalar_t*)shard_ptr; - scalar_t *sin_ptr = cos_ptr + half_shard_element_num; + m_scalar_t *cos_ptr = (m_scalar_t*)shard_ptr; + m_scalar_t *sin_ptr = cos_ptr + half_shard_element_num; // apply cos_sin memcopy - cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); + cos_sin_memory_access(cos, sin, cos_ptr, sin_ptr, token_id, shard_block_size, cos_stride, sin_stride, half_head_dim); __syncthreads(); //compute query - apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); + apply_emb_rotary_compute(query, cos_ptr, sin_ptr, query_stride, token_id, shard_block_size, half_head_dim, head_num, head_dim); //compute key - apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); + apply_emb_rotary_compute(key, cos_ptr, sin_ptr, key_stride, token_id, shard_block_size, half_head_dim, kv_head_num, head_dim); } -template +template void apply_rotary_embedding_and_cache_copy( at::Tensor& query, // [num_tokens, head_num, head_dim] at::Tensor& key, // [num_tokens, kv_head_num, head_dim] @@ -241,6 +242,8 @@ void apply_rotary_embedding_and_cache_copy( int sin_stride = sin.stride(0); int block_table_stride = block_tables.stride(0); + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + int vec_size = get_vec_size(query); if ((head_dim / 2) % vec_size != 0) { @@ -259,7 +262,7 @@ void apply_rotary_embedding_and_cache_copy( switch (vec_size) { case 1: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -283,7 +286,7 @@ void apply_rotary_embedding_and_cache_copy( ); break; case 2: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -307,7 +310,7 @@ void apply_rotary_embedding_and_cache_copy( ); break; case 4: - rotary_embedding_and_cache_copy_kernel<<>>( + rotary_embedding_and_cache_copy_kernel<<>>( query.data_ptr(), key.data_ptr(), value.data_ptr(), @@ -338,12 +341,12 @@ void apply_rotary_embedding_and_cache_copy( AT_CUDA_CHECK(cudaGetLastError()); } -template +template void apply_rotary_embedding( at::Tensor& query, // [total_tokens, head_num, head_dim] at::Tensor& key, // [total_tokens, kv_head_num, head_dim] at::Tensor& cos, // [total_tokens, head_dim] - at::Tensor& sin // [total_tokens, head_dim] + at::Tensor& sin // [total_tokens, head_dim] ){ int num_tokens = query.size(0); int head_num = query.size(1); @@ -355,6 +358,8 @@ void apply_rotary_embedding( int cos_stride = cos.stride(0); int sin_stride = sin.stride(0); + using m_scalar_t = typename colossalAI::common::ScalarTypeTrait::Type; + int vec_size = get_vec_size(query); if ((head_dim / 2) % vec_size != 0) { @@ -373,7 +378,7 @@ void apply_rotary_embedding( switch (vec_size) { case 1: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -389,7 +394,7 @@ void apply_rotary_embedding( ); break; case 2: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -405,7 +410,7 @@ void apply_rotary_embedding( ); break; case 4: - rotary_embedding_kernel<<>>( + rotary_embedding_kernel<<>>( query.data_ptr(), key.data_ptr(), cos.data_ptr(), @@ -436,12 
+441,14 @@ void rotary_embedding_and_cache_copy( at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] at::Tensor& sequence_lengths, // [batch_size] - at::Tensor& block_tables) // [batch_size, max_seq_len] + at::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision) { - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, query.scalar_type(), "rotary_embedding_and_cache_copy", - apply_rotary_embedding_and_cache_copy( + apply_rotary_embedding_and_cache_copy( query, key, value, @@ -458,12 +465,14 @@ void rotary_embedding( at::Tensor& query, // [total_tokens, head_num, head_dim] at::Tensor& key, // [total_tokens, kv_head_num, head_dim] at::Tensor& cos, // [total_tokens, head_dim] - at::Tensor& sin // [total_tokens, head_dim] + at::Tensor& sin, // [total_tokens, head_dim] + bool high_precision ){ - DISPATCH_FLOAT_HALF_AND_BFLOAT( + DISPATCH_FLOAT_HALF_AND_BFLOAT_WITH_HIGH_PRECISION( + high_precision, query.scalar_type(), "rotary_embedding", - apply_rotary_embedding( + apply_rotary_embedding( query, key, cos, diff --git a/extensions/csrc/cuda/pybind/inference.cpp b/extensions/csrc/cuda/pybind/inference.cpp index 4282f5382..541146e3a 100644 --- a/extensions/csrc/cuda/pybind/inference.cpp +++ b/extensions/csrc/cuda/pybind/inference.cpp @@ -9,11 +9,22 @@ void decode_kv_cache_memcpy( torch::Tensor& sequence_lengths, // [batch_size] torch::Tensor& block_tables); // [batch_size, max_seq_len] +void context_kv_cache_memcpy( + at::Tensor& key, // [num_tokens, head_num, head_dim] + at::Tensor& value, // [num_tokens, head_num, head_dim] + at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim] + at::Tensor& sequence_lengths, // [batch_size] + at::Tensor& cu_seqlens, // [batch_size + 1] + at::Tensor& block_tables, // [batch_size, max_seq_len] + int max_seq_len_in_batch); + void rotary_embedding( torch::Tensor& query, // [total_tokens, head_num, head_dim] torch::Tensor& key, // [total_tokens, kv_head_num, head_dim] torch::Tensor& cos, // [total_tokens, head_dim] - torch::Tensor& sin); // [total_tokens, head_dim] + torch::Tensor& sin, // [total_tokens, head_dim] + bool high_precision); void rotary_embedding_and_cache_copy( torch::Tensor& query, // [num_tokens, head_num, head_dim] @@ -25,7 +36,9 @@ void rotary_embedding_and_cache_copy( torch::Tensor& value_cache, // [num_blocks, num_heads, block_size, head_dim] torch::Tensor& sequence_lengths, // [batch_size] - torch::Tensor& block_tables); // [batch_size, max_seq_len] + torch::Tensor& block_tables, // [batch_size, max_seq_len] + bool high_precision); + torch::Tensor silu_and_mul(const torch::Tensor& ins); void rms_layernorm(torch::Tensor& out, // [..., hidden_size] @@ -42,6 +55,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy, "Copy the GPU memory of kvcache during the decode stage."); + m.def("context_kv_cache_memcpy", &context_kv_cache_memcpy, + "Copy the GPU memory of kvcache during the context stage."); + m.def( "rotary_embedding_and_cache_copy", &rotary_embedding_and_cache_copy, "performing Rotary Embedding-related calculations and KVCache Memcopy."); diff --git a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h index 524ef46c6..bd2465bea 100644 --- 
a/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h +++ b/extensions/csrc/cuda/scaled_upper_triang_masked_softmax.h @@ -11,6 +11,8 @@ #include #include +#include "utils/vector_copy_utils.h" + namespace { int log2_ceil(int value) { diff --git a/extensions/csrc/cuda/utils/vector_copy_utils.h b/extensions/csrc/cuda/utils/vector_copy_utils.h index 3c3afa0b3..5157ec738 100644 --- a/extensions/csrc/cuda/utils/vector_copy_utils.h +++ b/extensions/csrc/cuda/utils/vector_copy_utils.h @@ -11,16 +11,16 @@ template __device__ __inline__ void copy_vector(T *dst, const T *src) { using VT = typename colossalAI::cuda::utils::VecTypeTrait::Type; // Note(LiuYang): Here static_cast can't be used for cast between two pointer - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); } template <> __device__ __inline__ void copy_vector(float *dst, const float *src) { // Since the maximum memory alignment length is 128 bits, we choose float4 // here. - *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); + *(reinterpret_cast(dst)) = *(reinterpret_cast(src)); *(reinterpret_cast(dst + 4)) = - *(reinterpret_cast(src + 4)); + *(reinterpret_cast(src + 4)); } template diff --git a/extensions/inference/inference_ops_cuda.py b/extensions/inference/inference_ops_cuda.py index ae3754ca7..4e0afc819 100644 --- a/extensions/inference/inference_ops_cuda.py +++ b/extensions/inference/inference_ops_cuda.py @@ -12,6 +12,7 @@ class InferenceOpsCudaExtension(_CudaExtension): for fname in [ "cuda/pybind/inference.cpp", "cuda/decode_kv_cache_memcpy_kernel.cu", + "cuda/context_kv_cache_memcpy_kernel.cu", "cuda/fused_rotary_emb_and_cache_kernel.cu", "cuda/activation_kernel.cu", "cuda/rms_layernorm_kernel.cu", diff --git a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py index d5259a596..3fa17037f 100644 --- a/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py +++ b/tests/test_infer/test_ops/cuda/test_kv_cache_memcpy.py @@ -1,8 +1,10 @@ import pytest import torch +import torch.nn.functional as F from colossalai.kernel.kernel_loader import InferenceOpsLoader from colossalai.utils import get_current_device +from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2 from tests.test_infer.test_ops.triton.test_kvcache_copy import prepare_data inference_ops = InferenceOpsLoader().load() @@ -10,12 +12,7 @@ inference_ops = InferenceOpsLoader().load() HEAD_DIM = 4 -@pytest.mark.parametrize("bsz", [4, 7, 32]) -@pytest.mark.parametrize("block_size", [16, 32, 64]) -@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32]) -@pytest.mark.parametrize("num_kv_heads", [16]) -@pytest.mark.parametrize("same_context_len", [True, False]) -def test_copy_kv_to_caches( +def run_decode_copy_kv_to_caches( bsz: int, block_size: int, max_num_blocks_per_seq: int, @@ -61,5 +58,65 @@ def test_copy_kv_to_caches( assert torch.equal(v_target, v_source) +def run_context_copy_kv_to_cache( + bsz: int, + block_size: int, + max_num_blocks_per_seq: int, + num_kv_heads: int, + same_context_len: bool, +): + torch.manual_seed(123) + + assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads." 
+    max_seq_len = max_num_blocks_per_seq * block_size
+    dtype = torch.float16
+    device = get_current_device()
+
+    if same_context_len:
+        context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+    else:
+        context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+
+    num_tokens = torch.sum(context_lengths).item()
+
+    max_seq_len_in_batch = context_lengths.max()
+    cu_seqlens = F.pad(torch.cumsum(context_lengths, dim=0, dtype=torch.int32), (1, 0))
+
+    kv_size = (num_tokens, num_kv_heads, HEAD_DIM)
+    key = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    value = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+
+    k_cache_ref, v_cache_ref, block_tables = generate_caches_and_block_tables_v2(
+        key, value, context_lengths, bsz, max_num_blocks_per_seq, block_size, dtype, device
+    )
+
+    block_tables = block_tables.to(device=device)
+    k_cache = torch.zeros_like(k_cache_ref)
+    v_cache = torch.zeros_like(v_cache_ref)
+
+    inference_ops.context_kv_cache_memcpy(
+        key, value, k_cache, v_cache, context_lengths, cu_seqlens, block_tables, max_seq_len_in_batch
+    )
+
+    assert torch.equal(k_cache, k_cache_ref)
+    assert torch.equal(v_cache, v_cache_ref)
+
+
+@pytest.mark.parametrize("bsz", [4, 7, 32])
+@pytest.mark.parametrize("block_size", [16, 32, 64])
+@pytest.mark.parametrize("max_num_blocks_per_seq", [8, 32])
+@pytest.mark.parametrize("num_kv_heads", [16])
+@pytest.mark.parametrize("same_context_len", [True, False])
+def test_kv_cache_memcopy(
+    bsz: int,
+    block_size: int,
+    max_num_blocks_per_seq: int,
+    num_kv_heads: int,
+    same_context_len: bool,
+):
+    run_context_copy_kv_to_cache(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
+    run_decode_copy_kv_to_caches(bsz, block_size, max_num_blocks_per_seq, num_kv_heads, same_context_len)
+
+
 if __name__ == "__main__":
-    test_copy_kv_to_caches(4, 32, 8, 16, True)
+    test_kv_cache_memcopy(4, 32, 8, 16, True)
diff --git a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
index b9c0a3269..9e0a8b0db 100644
--- a/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_ops/cuda/test_rotary_embdding_unpad.py
@@ -1,3 +1,4 @@
+import numpy as np
 import pytest
 import torch
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
@@ -10,11 +11,18 @@ from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table
 from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
 
 
+def numpy_allclose(x, y, rtol, atol):
+    x_numpy = x.detach().cpu().numpy()
+    y_numpy = y.detach().cpu().numpy()
+
+    np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
+
+
 @pytest.mark.parametrize("BATCH_SIZE", [4])
 @pytest.mark.parametrize("SEQ_LEN", [64])
 @pytest.mark.parametrize("H", [32])
 @pytest.mark.parametrize("D", [64])
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
 def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     torch.manual_seed(10)
     TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
@@ -54,17 +62,36 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     kv_seq_lengths = past_kv_seq_lengths + 1
     block_tables = block_tables.to(device="cuda")
 
-    q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
-    k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
 
     new_q_copy = new_q.clone()
     new_k_copy = new_k.clone()
 
+    if dtype == torch.float16:
+        rtol = 1e-3
+        atol = 1e-3
+
+        new_q_fp16 = new_q.clone()
+        new_k_fp16 = new_k.clone()
+
+        high_precision_cos = cos[:BATCH_SIZE].to(torch.float32)
+        high_precision_sin = sin[:BATCH_SIZE].to(torch.float32)
+        high_precision_q = new_q.to(torch.float32)
+        high_precision_k = new_k.to(torch.float32)
+        q_ref = torch_rotary_emb(high_precision_q, high_precision_cos, high_precision_sin).to(torch.float16)
+        k_ref = torch_rotary_emb(high_precision_k, high_precision_cos, high_precision_sin).to(torch.float16)
+
+    else:
+        rtol = 1e-5
+        atol = 1e-7
+
+        q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+        k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
+
     inference_ops.rotary_embedding_and_cache_copy(
-        new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
+        new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables, True
     )
-    inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
+    inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin, True)
 
     past_kv_seq_len = kv_seq_lengths - 1
     target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
@@ -74,18 +101,26 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
     v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
     v_source = new_v.squeeze()
 
-    assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
-    assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
+    numpy_allclose(new_q, q_ref, rtol=rtol, atol=atol)
+    numpy_allclose(k_target, k_ref, rtol=rtol, atol=atol)
 
-    assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
-    assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
+    numpy_allclose(new_q_copy, q_ref, rtol=rtol, atol=atol)
+    numpy_allclose(new_k_copy, k_ref, rtol=rtol, atol=atol)
 
     assert k_target.shape == k_source.shape
-    assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
+    numpy_allclose(k_target, k_source, rtol=rtol, atol=atol)
 
     assert v_target.shape == v_source.shape
     assert torch.equal(v_target, v_source)
 
+    if dtype == torch.float16:
+        # The fp16 CUDA kernel with high_precision=True accumulates in fp32, so it tracks the fp32 reference more closely than the plain fp16 path; relax the tolerance when comparing the two kernel outputs.
+        rtol = 1e-3
+        atol = 1e-1
+        inference_ops.rotary_embedding(new_q_fp16, new_k_fp16, cos, sin, False)
+        numpy_allclose(new_q_copy, new_q_fp16, rtol=rtol, atol=atol)
+        numpy_allclose(new_k_copy, new_k_fp16, rtol=rtol, atol=atol)
+
 
 if __name__ == "__main__":
-    test_rotary_emb(16, 512, 4, 128, torch.float16)
+    test_rotary_emb(16, 64, 4, 128, torch.float16)
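
For reference, a minimal usage sketch of the new flag. Only dtype and high_precision come directly from this patch; the checkpoint path, the Auto* loaders, and the launch call follow the repository's existing inference examples and are assumptions here, not part of the diff.

# Hedged sketch: enable fp32 accumulation for the fp16 CUDA kernels added in this patch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import colossalai
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch(config={})  # assumes a torchrun-style single-process launch

model = AutoModelForCausalLM.from_pretrained("/path/to/llama")  # placeholder checkpoint path
tokenizer = AutoTokenizer.from_pretrained("/path/to/llama")

inference_config = InferenceConfig(
    dtype=torch.float16,   # high_precision only matters for fp16/bf16 inputs
    high_precision=True,   # silently reset to False when dtype is float32 (see _verify_config)
)
engine = InferenceEngine(model, tokenizer, inference_config)
print(engine.generate(prompts=["Introduce some landmarks in Beijing"]))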
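
The fused inference_ops.silu_and_mul call that now ends NopadLlamaMLP.forward keeps the semantics of the two lines it replaces; a plain-PyTorch equivalent (illustrative only, helper name is ours) is:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(gate_up_proj_out: torch.Tensor) -> torch.Tensor:
    # gate_up_proj_out is stacked as [2, num_tokens, intermediate_size]:
    # index 0 holds the gate projection, index 1 the up projection.
    return F.silu(gate_up_proj_out[0]) * gate_up_proj_out[1]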
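
A slow, plain-PyTorch sketch of what context_kv_cache_memcpy computes can help when reading the kernel's index arithmetic above. The [num_blocks, head_num, block_size, head_dim] cache layout and the F.pad(cumsum(...)) convention for cu_seqlens are taken from the patch; the helper name is ours.

import torch
import torch.nn.functional as F

def context_kv_cache_memcpy_ref(key, value, k_cache, v_cache, sequence_lengths, block_tables, block_size):
    # key/value: packed prompt tokens, [num_tokens, head_num, head_dim]
    # k_cache/v_cache: paged caches, [num_blocks, head_num, block_size, head_dim]
    cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.int32), (1, 0))
    for seq_id, seq_len in enumerate(sequence_lengths.tolist()):
        start = int(cu_seqlens[seq_id])  # first packed token of this sequence
        for t in range(seq_len):
            block_id = int(block_tables[seq_id, t // block_size])
            offset = t % block_size
            k_cache[block_id, :, offset, :] = key[start + t]
            v_cache[block_id, :, offset, :] = value[start + t]
    return k_cache, v_cache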
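
Finally, the effect of high_precision=True on the rotary-embedding kernels can be mimicked in PyTorch by doing the cos/sin math in fp32 and casting back, which is how the updated unit test builds its fp16 reference. The helper below is an illustrative stand-in; the half-split rotation layout is assumed to match the test's torch_rotary_emb.

import torch

def rotary_emb_fp32_ref(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x: [num_tokens, head_num, head_dim]; cos/sin: [num_tokens, head_dim // 2]
    out_dtype = x.dtype
    x, cos, sin = x.float(), cos.float(), sin.float()  # fp32 accumulation, as with high_precision=True
    half = x.size(-1) // 2
    x0, x1 = x[..., :half], x[..., half:]
    cos, sin = cos[:, None, :], sin[:, None, :]  # broadcast over heads
    out = torch.cat((x0 * cos - x1 * sin, x1 * cos + x0 * sin), dim=-1)
    return out.to(out_dtype)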