Merge branch 'feature/colossal-infer' into colossal-infer-cuda-graph

This commit is contained in:
Runyu Lu
2024-03-14 10:37:05 +08:00
committed by GitHub
53 changed files with 2133 additions and 252 deletions

View File

@@ -22,15 +22,11 @@ def setup_seed(seed):
def check_inference_engine(use_engine=False, prompt_template=None):
setup_seed(20)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
model = (
LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
model = LlamaForCausalLM(
LlamaConfig(
vocab_size=50000, hidden_size=512, intermediate_size=1536, num_attention_heads=4, num_hidden_layers=16
)
.cuda()
.half()
)
).cuda()
model = model.eval()
inputs = [
@@ -44,7 +40,7 @@ def check_inference_engine(use_engine=False, prompt_template=None):
top_k = 50
if use_engine:
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template)
inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
assert inference_engine.generation_config.max_new_tokens == output_len
inference_engine.add_request(prompts=inputs)

View File

@@ -0,0 +1,51 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("M", [2, 4, 8, 16])
@pytest.mark.parametrize("N", [64, 128, 512])
def test_rms_layernorm(M: int, N: int):
torch.manual_seed(123)
torch.cuda.empty_cache()
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
device = get_current_device()
dtype = torch.float16
eps = 1e-5
x_shape = (M, N)
w_shape = (x_shape[-1],)
weight = torch.ones(w_shape, dtype=dtype, device=device)
residual = torch.rand(x_shape, dtype=dtype, device=device)
residual_copy = residual.clone()
rms_norm = LlamaRMSNorm(hidden_size=N, eps=eps).cuda()
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
x_copy = x.clone()
y_cuda = torch.empty_like(x)
inference_ops.rms_layernorm(y_cuda, x, weight, eps)
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
inference_ops.fused_add_rms_layernorm(x, residual, weight, eps)
y_cuda = x
x = x_copy + residual_copy
y_llama = rms_norm.forward(x).to(dtype)
assert y_cuda.shape == y_llama.shape
assert torch.allclose(y_cuda, y_llama, atol=1e-5, rtol=1e-3)
assert torch.allclose(x, residual, atol=1e-5, rtol=1e-3)
if __name__ == "__main__":
test_rms_layernorm(16, 512)

View File

@@ -0,0 +1,91 @@
import pytest
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb
from colossalai.kernel.kernel_loader import InferenceOpsLoader
inference_ops = InferenceOpsLoader().load()
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2
from tests.test_infer.test_ops.triton.test_rotary_embdding_unpad import torch_rotary_emb
@pytest.mark.parametrize("BATCH_SIZE", [4])
@pytest.mark.parametrize("SEQ_LEN", [64])
@pytest.mark.parametrize("H", [32])
@pytest.mark.parametrize("D", [64])
@pytest.mark.parametrize("dtype", [torch.float16])
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data
block_size = 32
max_blocks_per_sequence = (TOTAL_TOKENS + block_size - 1) // block_size
q_shape = (TOTAL_TOKENS, H, D)
q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
k_shape = (TOTAL_TOKENS, H, D)
k = -2.3 + 0.5 * torch.randn(k_shape, dtype=dtype, device="cuda")
cos_shape = (TOTAL_TOKENS, D // 2)
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
cache_shape = (BATCH_SIZE * max_blocks_per_sequence, H, block_size, D)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
v = torch.randn_like(k)
v_cache = torch.zeros_like(k_cache)
past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_blocks_per_sequence, block_size
)
new_k = torch.randn((BATCH_SIZE, H, D), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
kv_seq_lengths = past_kv_seq_lengths + 1
block_tables = block_tables.to(device="cuda")
q_ref = torch_rotary_emb(new_q, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
k_ref = torch_rotary_emb(new_k, cos[:BATCH_SIZE], sin[:BATCH_SIZE])
new_q_copy = new_q.clone()
new_k_copy = new_k.clone()
inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
)
inference_ops.rotary_embedding(new_q_copy, new_k_copy, cos, sin)
past_kv_seq_len = kv_seq_lengths - 1
target_block_ids = block_tables[range(0, block_tables.size(0)), past_kv_seq_len // block_size]
offsets_in_block = past_kv_seq_len % block_size
k_target = k_cache[target_block_ids, :, offsets_in_block, :].squeeze()
k_source = new_k_copy.squeeze()
v_target = v_cache[target_block_ids, :, offsets_in_block, :].squeeze()
v_source = new_v.squeeze()
assert torch.allclose(new_q, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(k_target, k_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(new_q_copy, q_ref, atol=1e-6, rtol=1e-6)
assert torch.allclose(new_k_copy, k_ref, atol=1e-6, rtol=1e-6)
assert k_target.shape == k_source.shape
assert torch.allclose(k_target, k_source, atol=1e-6, rtol=1e-6)
assert v_target.shape == v_source.shape
assert torch.equal(v_target, v_source)
if __name__ == "__main__":
test_rotary_emb(16, 512, 4, 128, torch.float16)

View File

@@ -0,0 +1,33 @@
import pytest
import torch
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
inference_ops = InferenceOpsLoader().load()
@pytest.mark.parametrize("SHAPE_X", [2])
@pytest.mark.parametrize("SHAPE_Y", [64])
@pytest.mark.parametrize("SHAPE_Z", [11008])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16])
def test_silu_and_mul(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype):
torch.manual_seed(5)
device = get_current_device()
ref_input = torch.randn(SHAPE_X, SHAPE_Y, SHAPE_Z, dtype=dtype, device=device)
origin_input = ref_input.clone()
act_out = torch.nn.functional.silu(ref_input[0], inplace=True)
ref_out = act_out * ref_input[1]
origin_out = inference_ops.silu_and_mul(origin_input)
if dtype == torch.float32:
assert torch.allclose(origin_out, ref_out, atol=1e-5, rtol=1e-5)
else:
assert torch.allclose(origin_out, ref_out, atol=1e-3, rtol=1e-3)
if __name__ == "__main__":
test_silu_and_mul(2, 64, 11008, torch.float32)
test_silu_and_mul(2, 64, 11008, torch.float16)