feat rmsnorm cuda kernel and add unittest, benchmark script (#5417)

This commit is contained in:
Steve Luo
2024-03-08 16:21:12 +08:00
committed by GitHub
parent 2b28b54ac6
commit f7aecc0c6b
8 changed files with 244 additions and 49 deletions

View File

@@ -22,15 +22,11 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
.cuda()
.half()
)
).cuda()
model = model.eval()
inputs = [
@@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
device = get_current_device()
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device=device)
residual = torch.rand(x_shape, dtype=dtype, device=device)
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_cuda = torch.empty_like(x)
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
y_cuda = x
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_rms_layernorm(16, 512)