diff --git a/colossalai/inference/README.md b/colossalai/inference/README.md index d0c281e05..4aca7aeb0 100644 --- a/colossalai/inference/README.md +++ b/colossalai/inference/README.md @@ -34,11 +34,13 @@ In this section we discuss how the colossal inference works and integrates with - [x] policy - [x] context forward - [x] token forward + - [x] support flash-decoding - [ ] Replace the kernels with `faster-transformer` in token-forward stage - [ ] Support all models - [x] Llama + - [x] Llama-2 - [x] Bloom - - [ ] Chatglm2 + - [x] Chatglm2 - [ ] Benchmarking for all models ## Get started @@ -68,6 +70,12 @@ git clone https://github.com/ModelTC/lightllm git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . + +# also, install xformers from source: +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers + ``` ### Docker @@ -89,7 +97,10 @@ git checkout 28c1267cfca536b7b4f28e921e03de735b003039 cd lightllm pip3 install -e . - +# install xformers from source +pip install ninja +# Set TORCH_CUDA_ARCH_LIST if running and building on different GPU types +pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers ``` ### Dive into fast-inference! diff --git a/colossalai/inference/tensor_parallel/engine.py b/colossalai/inference/tensor_parallel/engine.py index e410532d8..1c203140c 100644 --- a/colossalai/inference/tensor_parallel/engine.py +++ b/colossalai/inference/tensor_parallel/engine.py @@ -311,6 +311,7 @@ class TPInferEngine: seq_start_indexes[i] = start_index start_index += curr_seq_len max_len_in_batch = curr_seq_len if curr_seq_len > max_len_in_batch else max_len_in_batch + block_loc = torch.empty((batch_size, self.max_input_len + self.max_output_len), dtype=torch.long, device="cuda") batch_infer_state = BatchInferState(batch_size, max_len_in_batch) batch_infer_state.seq_len = seq_lengths.to("cuda") diff --git a/colossalai/inference/tensor_parallel/modeling/bloom.py b/colossalai/inference/tensor_parallel/modeling/bloom.py index d84c567ea..0ad3994b0 100644 --- a/colossalai/inference/tensor_parallel/modeling/bloom.py +++ b/colossalai/inference/tensor_parallel/modeling/bloom.py @@ -19,6 +19,12 @@ from transformers.utils import logging from colossalai.inference.tensor_parallel.batch_infer_state import BatchInferState from colossalai.kernel.triton import bloom_context_attn_fwd, copy_kv_cache_to_dest, token_attention_fwd +try: + from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_bloom_context_attention_fwd + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + def generate_alibi(n_head, dtype=torch.float16): """ @@ -460,7 +466,10 @@ class BloomInferenceForwards: # output = self.output[:batch_size*q_length, :, :] output = torch.empty_like(q) - bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) + if HAS_LIGHTLLM_KERNEL: + lightllm_bloom_context_attention_fwd(q, k, v, output, alibi, b_start_loc, b_seq_len, max_input_len) + else: + bloom_context_attn_fwd(q, k, v, output, b_start_loc, b_seq_len, max_input_len, alibi) context_layer = output.view(batch_size, q_length, H * D_HEAD) else: diff --git a/colossalai/inference/tensor_parallel/modeling/llama.py b/colossalai/inference/tensor_parallel/modeling/llama.py index a17b901dc..8573bb965 100644 --- a/colossalai/inference/tensor_parallel/modeling/llama.py +++ b/colossalai/inference/tensor_parallel/modeling/llama.py @@ -1,4 +1,6 @@ from typing import List, Optional, Tuple +import math +import copy import torch from transformers.modeling_outputs import BaseModelOutputWithPast @@ -10,24 +12,11 @@ from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttention from ._utils import copy_kv_to_mem_cache -try: - from vllm import layernorm_ops, pos_encoding_ops - - rms_norm = layernorm_ops.rms_norm - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - print( - "if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch" - ) - HAS_VLLM_KERNERL = False - try: from lightllm.models.llama2.triton_kernel.context_flashattention_nopad import ( context_attention_fwd as lightllm_llama2_context_attention_fwd, ) + from lightllm.models.llama.triton_kernel.context_flashattention_nopad import context_attention_fwd as lightllm_context_attention_fwd from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd as llama_rotary_embedding_fwd HAS_LIGHTLLM_KERNEL = True @@ -35,6 +24,13 @@ except: print("please install lightllm from source to run inference: https://github.com/ModelTC/lightllm") HAS_LIGHTLLM_KERNEL = False +try: + from flash_attn import flash_attn_with_kvcache + HAS_FLASH_KERNEL = True +except: + HAS_FLASH_KERNEL = False + print("please install flash attentiom from https://github.com/Dao-AILab/flash-attention") + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -54,6 +50,71 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed +def llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=1): + if num_key_value_groups == 1: + if HAS_LIGHTLLM_KERNEL is False: + llama_context_attn_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + lightllm_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernels to run llama2 model" + lightllm_llama2_context_attention_fwd( + query_states, + key_states, + value_states, + attn_output, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + +def llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=1): + assert HAS_LIGHTLLM_KERNEL is True, "You have to install lightllm kernel to run token attention for llama models" + if num_key_value_groups == 1: + token_attention_fwd( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + ) + else: + Llama2TokenAttentionForwards.token_attn( + query_states, + infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], + infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], + attn_output, + infer_state.block_loc, + infer_state.start_loc, + infer_state.seq_len, + # infer_state.cache_manager.past_key_values_length, + infer_state.max_len_in_batch, + infer_state.other_kv_index, + ) + class LlamaInferenceForwards: """ @@ -204,7 +265,8 @@ class LlamaInferenceForwards: hidden_states=all_hidden_states, attentions=all_self_attns, ) - + + @staticmethod def llama_decoder_layer_forward( self: LlamaDecoderLayer, @@ -247,6 +309,7 @@ class LlamaInferenceForwards: outputs += (present_key_value,) return outputs + @staticmethod def llama_flash_attn_kvcache_forward( @@ -295,27 +358,8 @@ class LlamaInferenceForwards: infer_state.cache_manager, ) attn_output = torch.empty_like(query_states) - - if self.num_key_value_groups == 1: - llama_context_attn_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) - else: - lightllm_llama2_context_attention_fwd( - query_states, - key_states, - value_states, - attn_output, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + + llama_triton_context_attention(query_states, key_states, value_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) else: if infer_state.decode_is_contiguous: # if decode is contiguous, then we copy to key cache and value cache in cache manager directly @@ -337,35 +381,26 @@ class LlamaInferenceForwards: infer_state.decode_mem_index, infer_state.cache_manager, ) - - # second token and follows - # kv = torch.stack((key_states, value_states), dim=2) - # (batch_size, seqlen, nheads, headdim) - attn_output = torch.empty_like(query_states) - - if self.num_key_value_groups == 1: - token_attention_fwd( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - ) + + HAS_LIGHTLLM_KERNEL = False + if HAS_LIGHTLLM_KERNEL: + attn_output = torch.empty_like(query_states) + llama_triton_token_attention(query_states, attn_output, infer_state, num_key_value_groups=self.num_key_value_groups) else: - Llama2TokenAttentionForwards.token_attn( - query_states, - infer_state.cache_manager.key_buffer[infer_state.decode_layer_id], - infer_state.cache_manager.value_buffer[infer_state.decode_layer_id], - attn_output, - infer_state.block_loc, - infer_state.start_loc, - infer_state.seq_len, - infer_state.max_len_in_batch, - infer_state.other_kv_index, - ) + heads_per_group = self.num_heads // self.num_key_value_heads + cache_k = infer_state.cache_manager.key_buffer[infer_state.decode_layer_id] + cache_v = infer_state.cache_manager.value_buffer[infer_state.decode_layer_id] + + query_states = query_states.view(bsz, -1, self.num_heads, self.head_dim) + copy_cache_k = cache_k.view(bsz, -1, self.num_key_value_heads, self.head_dim) + copy_cache_v = cache_v.view(bsz, -1, self.num_key_value_heads, self.head_dim) + + attn_output = flash_attn_with_kvcache(q = query_states, + k_cache = copy_cache_k, + v_cache = copy_cache_v, + softmax_scale = 1/ math.sqrt(self.head_dim), + causal = True) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) @@ -374,22 +409,3 @@ class LlamaInferenceForwards: # return past_key_value as None return attn_output, None, None - -def get_llama_vllm_rmsnorm_forward(): - if HAS_VLLM_KERNERL: - - def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - self.weight.data, - self.variance_epsilon, - ) - - return out - - return _vllm_rmsnorm_forward - else: - return None diff --git a/colossalai/inference/tensor_parallel/policies/llama.py b/colossalai/inference/tensor_parallel/policies/llama.py index 7e163efe0..d6c072c74 100644 --- a/colossalai/inference/tensor_parallel/policies/llama.py +++ b/colossalai/inference/tensor_parallel/policies/llama.py @@ -9,7 +9,7 @@ from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy from ..modeling._utils import init_to_get_rotary -from ..modeling.llama import LlamaInferenceForwards, get_llama_vllm_rmsnorm_forward +from ..modeling.llama import LlamaInferenceForwards try: from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward as lightllm_rmsnorm_forward @@ -105,9 +105,6 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy): infer_forward = None if HAS_TRITON_RMSNORM: infer_forward = get_triton_rmsnorm_forward() - else: - # NOTE: adding rms_norm from cuda kernels caused precision issue, fix @tiandiao123 - infer_forward = get_llama_vllm_rmsnorm_forward() if infer_forward is not None: method_replacement = {"forward": partial(infer_forward)} diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 1b4f6e44b..5ce6f2c21 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -5,7 +5,6 @@ import torch try: import triton import triton.language as tl - HAS_TRITON = True except ImportError: HAS_TRITON = False @@ -155,39 +154,43 @@ if HAS_TRITON: num_warps = 4 if Lk <= 64 else 8 tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - alibi, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - # manually setting this blcok num, we can use tuning config to futher speed-up - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + alibi, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return @torch.no_grad() @@ -207,36 +210,40 @@ if HAS_TRITON: tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 # num_warps = 4 - _context_flash_attention_kernel[grid]( - q, - k, - v, - sm_scale, - b_start_loc, - b_seq_len, - tmp, - None, - o, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - tmp.stride(0), - tmp.stride(1), - tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) + if triton.__version__ < "2.1.0": + _context_flash_attention_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + tmp, + None, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + tmp.stride(0), + tmp.stride(1), + tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + else: + raise Exception("Please install lightllm kernels from https://github.com/ModelTC/lightllm since your triton version is larger than 2.0.0") + return \ No newline at end of file diff --git a/examples/inference/bench_llama.py b/examples/inference/bench_llama.py index f3e742dfb..56bf062e2 100644 --- a/examples/inference/bench_llama.py +++ b/examples/inference/bench_llama.py @@ -105,8 +105,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-p", "--path", type=str, help="Model path", required=True) parser.add_argument("-tp", "--tp_size", type=int, default=1, help="Tensor parallel size") - parser.add_argument("-b", "--batch_size", type=int, default=16, help="Maximum batch size") - parser.add_argument("--input_len", type=int, default=256, help="Maximum input length") + parser.add_argument("-b", "--batch_size", type=int, default=32, help="Maximum batch size") + parser.add_argument("--input_len", type=int, default=1024, help="Maximum input length") parser.add_argument("--output_len", type=int, default=128, help="Maximum output length") parser.add_argument( "--test_mode", type=str, help="Test mode", default="e2e_test", choices=["e2e_test", "decoder_test"] diff --git a/tests/test_infer/test_bloom_infer.py b/tests/test_infer/test_bloom_infer.py index ba978ad9b..d4366758d 100644 --- a/tests/test_infer/test_bloom_infer.py +++ b/tests/test_infer/test_bloom_infer.py @@ -10,6 +10,12 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + TP_SIZE = 2 MAX_BATCH_SIZE = 4 MAX_INPUT_LEN = 16 @@ -52,7 +58,7 @@ def check_bloom(rank, world_size, port): run() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_chatglm2_infer.py b/tests/test_infer/test_chatglm2_infer.py index f9f7670c4..09bb8a949 100644 --- a/tests/test_infer/test_chatglm2_infer.py +++ b/tests/test_infer/test_chatglm2_infer.py @@ -12,6 +12,12 @@ from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import Ch from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 1 BATCH_SIZE = 8 @@ -61,7 +67,7 @@ def check_chatglm2(rank, world_size, port): run_chatglm2_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama2_infer.py b/tests/test_infer/test_llama2_infer.py index 0eebed889..13e7a6182 100644 --- a/tests/test_infer/test_llama2_infer.py +++ b/tests/test_infer/test_llama2_infer.py @@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -57,7 +63,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer/test_llama_infer.py b/tests/test_infer/test_llama_infer.py index b424525a3..a4f54d197 100644 --- a/tests/test_infer/test_llama_infer.py +++ b/tests/test_infer/test_llama_infer.py @@ -12,6 +12,12 @@ from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn +try: + import lightllm + HAS_LIGHTLLM_KERNEL = True +except: + HAS_LIGHTLLM_KERNEL = False + os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" TPSIZE = 2 BATCH_SIZE = 8 @@ -55,7 +61,7 @@ def check_llama(rank, world_size, port): run_llama_test() -@pytest.mark.skipif(not CUDA_SUPPORT, reason="kv-cache manager engine requires cuda version to be higher than 11.5") +@pytest.mark.skipif(not CUDA_SUPPORT or not HAS_LIGHTLLM_KERNEL, reason="kv-cache manager engine requires cuda version to be higher than 11.5") @pytest.mark.dist @rerun_if_address_is_in_use() @clear_cache_before_run() diff --git a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py deleted file mode 100644 index a4d893f8e..000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -import pytest -import torch -from torch import nn - -try: - from vllm import layernorm_ops - - rms_norm = layernorm_ops.rms_norm - HAS_VLLM_KERNERL = True -except: - print("please install vllm kernels to install rmsnorm") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -class LlamaRMSNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-6): - """ - LlamaRMSNorm is equivalent to T5LayerNorm - """ - super().__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.variance_epsilon = eps - - def forward(self, hidden_states): - input_dtype = hidden_states.dtype - hidden_states = hidden_states.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - return self.weight * hidden_states.to(input_dtype) - - -def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): - x = hidden_states - out = torch.empty_like(x) - rms_norm( - out, - x, - weight, - variance_epsilon, - ) - return out - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rmsnorm(): - data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") - hg_rms = LlamaRMSNorm(64) - hg_rms = hg_rms.half().cuda() - out_torch = hg_rms(data) - out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) - - check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) - assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" - - -if __name__ == "__main__": - test_rmsnorm() diff --git a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py deleted file mode 100644 index 40451ef66..000000000 --- a/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- -from typing import Tuple - -import pytest -import torch -import torch.nn as nn -import torch.nn.functional as F -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half - -try: - from vllm import pos_encoding_ops - - rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox - HAS_VLLM_KERNERL = True -except: - print("fall back to original rotary_embedding_neox of huggingface") - print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") - HAS_VLLM_KERNERL = False - - -def rotate_half(x: torch.Tensor) -> torch.Tensor: - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb( - q: torch.Tensor, - k: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, -) -> Tuple[torch.Tensor, torch.Tensor]: - q_embed = (q * cos) + (rotate_half(q) * sin) - k_embed = (k * cos) + (rotate_half(k) * sin) - return q_embed, k_embed - - -class RefRotaryEmbeddingNeox(nn.Module): - """Reference implementation of the GPT-NeoX style rotary embedding.""" - - def __init__( - self, - dim: int, - max_position_embeddings: int = 2048, - base: int = 10000, - ) -> None: - super().__init__() - self.rotary_dim = dim - self.max_position_embeddings = max_position_embeddings - - # Create cos and sin embeddings. - inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2) / dim)) - t = torch.arange(max_position_embeddings).float() - freqs = torch.einsum("i,j->ij", t, inv_freq.float()) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos().to(dtype=inv_freq.dtype) - sin = emb.sin().to(dtype=inv_freq.dtype) - self.register_buffer("cos_cached", cos, persistent=False) - self.register_buffer("sin_cached", sin, persistent=False) - - def forward( - self, - positions: torch.Tensor, # [num_tokens] - query: torch.Tensor, # [num_tokens, num_heads, head_size] - key: torch.Tensor, # [num_tokens, num_heads, head_size] - ) -> Tuple[torch.Tensor, torch.Tensor]: - query_rot = query[..., : self.rotary_dim] - query_pass = query[..., self.rotary_dim :] - key_rot = key[..., : self.rotary_dim] - key_pass = key[..., self.rotary_dim :] - - query_rot = query_rot.transpose(0, 1) - key_rot = key_rot.transpose(0, 1) - cos = F.embedding(positions, self.cos_cached) - sin = F.embedding(positions, self.sin_cached) - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) - query_rot = query_rot.transpose(0, 1).contiguous() - key_rot = key_rot.transpose(0, 1).contiguous() - - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - - # Output query/key shape: [num_tokens, num_tokens, head_size] - return query, key - - -def run_rotary_embedding_neox( - num_tokens: int, - num_heads: int, - head_size: int, - max_position: int, - rotary_dim: int, - dtype: torch.dtype, - base: int = 10000, -) -> None: - positions = torch.randint(0, max_position, (num_tokens,), device="cuda") - query = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - key = torch.randn(num_tokens, num_heads * head_size, dtype=dtype, device="cuda") - - # Create the rotary embedding. - inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2) / rotary_dim)) - t = torch.arange(max_position).float() - freqs = torch.einsum("i,j -> ij", t, inv_freq.float()) - cos = freqs.cos() - sin = freqs.sin() - cos_sin_cache = torch.cat((cos, sin), dim=-1) - cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda") - - # Run the kernel. The kernel is in-place, so we need to clone the inputs. - out_query = query.clone() - out_key = key.clone() - rotary_embedding_neox( - positions, - out_query, - out_key, - head_size, - cos_sin_cache, - ) - - # Run the reference implementation. - ref_rotary_embedding = RefRotaryEmbeddingNeox( - dim=rotary_dim, - max_position_embeddings=max_position, - base=base, - ).to(dtype=dtype, device="cuda") - ref_query, ref_key = ref_rotary_embedding( - positions, - query.view(num_tokens, num_heads, head_size), - key.view(num_tokens, num_heads, head_size), - ) - ref_query = ref_query.view(num_tokens, num_heads * head_size) - ref_key = ref_key.view(num_tokens, num_heads * head_size) - - # Compare the results. - assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) - assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) - - -@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") -def test_rotary_embedding(): - run_rotary_embedding_neox( - num_tokens=1024, - num_heads=8, - head_size=64, - max_position=8192, - rotary_dim=64, - dtype=torch.float16, - ) - - -if __name__ == "__main__": - test_rotary_embedding()