[Refactor] Integrated some lightllm kernels into token-attention (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug
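
A minimal usage sketch of one of the integrated kernels, based on the calling convention exercised by the (now removed) rotary-embedding test in this diff; the shapes are illustrative values taken from that test, not a fixed requirement:

import torch
from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd

# one decoding step, single token per sequence (shapes follow the deleted test)
seq_len, head_num, head_dim = 1, 32, 128
q = torch.randn((seq_len, head_num, head_dim), dtype=torch.float16, device="cuda")
cos = torch.randn((seq_len, head_dim // 2), dtype=torch.float16, device="cuda")
sin = torch.randn((seq_len, head_dim // 2), dtype=torch.float16, device="cuda")

# the lightllm-derived triton kernel rotates q in place rather than returning a new tensor
rotary_embedding_fwd(q, cos, sin)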

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Cuiqing Li
2023-10-19 22:22:47 +08:00
committed by GitHub
parent 11009103be
commit 3a41e8304e
20 changed files with 160 additions and 1555 deletions

@@ -1,63 +0,0 @@
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    # naive torch reference for a single decoding step (query length 1)
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)
    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)
    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test():
    Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
    dtype = torch.float16
    # attn out: 2,4096
    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")

    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
    other_kv_index = 2048

    # each sequence occupies a contiguous slice of the flat kv cache
    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len
    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    Llama2TokenAttentionForwards.token_attn(
        q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
    )
    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
    assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)


if __name__ == "__main__":
    test()

@@ -1,55 +0,0 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    # reference rotary embedding: rotate the two halves of the head dimension
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.half
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # compute the torch reference first, then let the triton kernel update x in place
    y_torch = torch_rotary_emb(x, cos, sin)
    rotary_embedding_fwd(x, cos, sin)
    y_triton = x
    # compare
    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_rotary_emb()

@@ -1,74 +0,0 @@
import math

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
    # reference attention scores via batched matmul
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    keys = xk
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    scores = (
        (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1)
    )
    return scores


def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
    # alternative reference: elementwise multiply and sum over the head dimension
    xq = xq.view(1, num_head, head_dim)
    xk = xk.view(seqlen, num_head, head_dim)
    logics = torch.sum(xq * xk, dim=-1, keepdim=False)
    logics = logics.transpose(0, 1) / math.sqrt(head_dim)
    return logics


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_attn_1():
    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
    dtype = torch.float16

    q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")

    b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out.squeeze()
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_attn_1()

@@ -1,63 +0,0 @@
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_attn(V, P, bs, seqlen, num_head, head_dim):
    # reference: weight the cached values V by the attention probabilities P
    V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
    attn_out = torch.matmul(P, V)
    return attn_out


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_token_attn_2():
    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
    dtype = torch.float16

    V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    Prob = (
        torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
        .normal_(mean=0.4, std=0.2)
        .reshape(head_num, batch_size, seq_len)
        .softmax(-1)
        .reshape(head_num, batch_size * seq_len)
    )
    attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_token_attn_2()

@@ -3,16 +3,13 @@ import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")   # removed in this commit
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")  # added: kernels now require CUDA >= 11.6


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):