Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[Refactor] Integrated some lightllm kernels into token-attention (#4946)
* add some requirements for inference
* clean codes
* add codes
* add some lightllm deps
* clean codes
* hello
* delete rms files
* add some comments
* add comments
* add doc
* add lightllm deps
* add lightllm chatglm2 kernels
* add lightllm chatglm2 kernels
* replace rotary embedding with lightllm kernel
* add some comments
* add some comments
* add some comments
* add
* replace fwd kernel att1
* fix an arg
* add
* add
* fix token attention
* add some comments
* clean codes
* modify comments
* fix readme
* fix bug
* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
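For context, the tests touched here exercise decode-time token attention: each sequence contributes a single new query token that attends over its cached keys/values stored in a flat KV buffer, addressed through kv_cache_loc / kv_cache_start_loc / kv_cache_seq_len style index tensors. Below is a minimal PyTorch sketch of that reference computation, assuming one query token per sequence; the tensor names are illustrative, not the lightllm or ColossalAI kernel API.

import torch

def token_decode_attention(q, k_cache, v_cache, cache_loc, seq_lens):
    """Reference decode-step attention over a flat KV cache.

    q:         (batch, num_head, head_dim)  - query for the newest token of each sequence
    k_cache:   (num_slots, num_head, head_dim)
    v_cache:   (num_slots, num_head, head_dim)
    cache_loc: (batch, max_seq_len) int32   - physical slot index of each cached token
    seq_lens:  (batch,) int32               - number of valid cached tokens per sequence
    """
    bs, num_head, head_dim = q.shape
    out = torch.empty_like(q)
    for i in range(bs):
        n = int(seq_lens[i])
        idx = cache_loc[i, :n].long()                      # gather this sequence's cache slots
        k = k_cache[idx]                                   # (n, num_head, head_dim)
        v = v_cache[idx]
        # score the single query token against every cached key, per head
        scores = torch.einsum("hd,nhd->hn", q[i], k) / head_dim**0.5
        prob = torch.softmax(scores, dim=-1)               # softmax over cached positions
        out[i] = torch.einsum("hn,nhd->hd", prob, v)       # probability-weighted sum of values
    return out

The torch_att / torch_attn reference functions in the deleted tests below compute the same quantity in vectorized form and are used to check the Triton kernels numerically.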
@@ -1,63 +0,0 @@
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import Llama2TokenAttentionForwards

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    xv = xv.view(bs, seqlen, num_head, head_dim)

    logics = torch.sum(xq * xk, dim=3, keepdim=False) * 1 / (head_dim**0.5)
    prob = torch.softmax(logics, dim=1)
    prob = prob.view(bs, seqlen, num_head, 1)

    return torch.sum(prob * xv, dim=1, keepdim=False)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test():
    Z, head_num, seq_len, head_dim = 2, 32, 2048, 128
    dtype = torch.float16

    # attn out: 2,4096
    q = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2)
    v = torch.empty((Z * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2)
    o = torch.empty((Z, head_num, head_dim), dtype=dtype, device="cuda")
    max_kv_cache_len = seq_len
    kv_cache_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((Z, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda")
    other_kv_index = 2048

    kv_cache_seq_len[:] = seq_len
    kv_cache_start_loc[0] = 0
    kv_cache_start_loc[1] = seq_len

    for i in range(Z):
        kv_cache_loc[i, :] = torch.arange(i * seq_len, (i + 1) * seq_len, dtype=torch.int32, device="cuda")

    Llama2TokenAttentionForwards.token_attn(
        q, k, v, o, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, max_kv_cache_len, other_kv_index
    )
    torch_out = torch_att(q, k, v, Z, seq_len, head_num, head_dim)
    assert torch.allclose(torch_out, o, atol=1e-3, rtol=0)


if __name__ == "__main__":
    test()
@@ -1,55 +0,0 @@
# Adapted from ModelTC https://github.com/ModelTC/lightllm

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.rotary_embedding_kernel import rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.half
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # forward pass
    y_torch = torch_rotary_emb(x, cos, sin)
    rotary_embedding_fwd(x, cos, sin)
    y_triton = x
    # compare
    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_rotary_emb()
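The test above feeds random cos/sin tables, since it only checks numerical agreement between the Triton kernel and torch_rotary_emb. In actual use the tables come from the standard RoPE formulation; a hedged sketch of how such tables are typically precomputed (the helper name and default base are illustrative, not taken from this repo):

import torch

def build_rope_tables(positions: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Precompute RoPE cos/sin tables of shape (len(positions), head_dim // 2)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]   # (seq_len, head_dim // 2)
    return torch.cos(angles), torch.sin(angles)

# Example: tables for the single decode position used above (SEQ_LEN = 1)
# cos, sin = build_rope_tables(torch.tensor([0]), head_dim=128)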
@@ -1,74 +0,0 @@
import math

import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_1

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_attn(xq, xk, bs, seqlen, num_head, head_dim):
    xq = xq.view(bs, 1, num_head, head_dim)
    xk = xk.view(bs, seqlen, num_head, head_dim)
    keys = xk
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    scores = (
        (torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)).squeeze().transpose(0, 1).reshape(num_head, -1)
    )
    return scores


def torch_attn_1(xq, xk, seqlen, num_head, head_dim):
    xq = xq.view(1, num_head, head_dim)
    xk = xk.view(seqlen, num_head, head_dim)
    logics = torch.sum(xq * xk, dim=-1, keepdim=False)

    logics = logics.transpose(0, 1) / math.sqrt(head_dim)
    return logics


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_attn_1():
    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128

    dtype = torch.float16

    q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
    attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")

    b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out.squeeze()
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_attn_1()
@@ -1,63 +0,0 @@
import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton.token_attention_kernel import token_attn_fwd_2

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_attn(V, P, bs, seqlen, num_head, head_dim):
    V = V.view(bs, seqlen, num_head, head_dim).transpose(1, 2)
    P = P.reshape(num_head, bs, 1, seqlen).transpose(0, 1)
    attn_out = torch.matmul(P, V)

    return attn_out


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_token_attn_2():
    batch_size, seq_len, head_num, head_dim = 17, 1025, 12, 128
    dtype = torch.float16

    V = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=10)
    Prob = (
        torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
        .normal_(mean=0.4, std=0.2)
        .reshape(head_num, batch_size, seq_len)
        .softmax(-1)
        .reshape(head_num, batch_size * seq_len)
    )
    attn_out = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda")

    kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
    kv_cache_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        kv_cache_start_loc[i] = i * seq_len
        kv_cache_seq_len[i] = seq_len
        kv_cache_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")

    token_attn_fwd_2(Prob, V, attn_out, kv_cache_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    torch_out = torch_attn(V, Prob, batch_size, seq_len, head_num, head_dim).squeeze()
    o = attn_out
    print("max ", torch.max(torch.abs(torch_out - o)))
    print("mean ", torch.mean(torch.abs(torch_out - o)))
    assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_token_attn_2()
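Taken together, token_attn_fwd_1 and token_attn_fwd_2 split decode-step attention into a score stage and a weighted-sum stage, with a softmax over cached positions in between. A minimal sketch of that composition in plain PyTorch, mirroring the torch references in the two tests above; the per-sequence loop, contiguous cache layout, and function name are illustrative, not the kernel API.

import math

import torch

def two_stage_token_attention(q, k_cache, v_cache, seq_len):
    """Compose the two reference stages for equal-length sequences stored contiguously."""
    bs, num_head, head_dim = q.shape
    out = torch.empty_like(q)
    for i in range(bs):
        k = k_cache[i * seq_len : (i + 1) * seq_len]        # (seq_len, num_head, head_dim)
        v = v_cache[i * seq_len : (i + 1) * seq_len]
        # stage 1: per-head dot products of the new query with every cached key
        scores = torch.sum(q[i].unsqueeze(0) * k, dim=-1).transpose(0, 1) / math.sqrt(head_dim)
        # softmax over the cached positions
        prob = torch.softmax(scores, dim=-1)                 # (num_head, seq_len)
        # stage 2: probability-weighted sum of the cached values
        out[i] = torch.einsum("hn,nhd->hd", prob, v)
    return out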
@@ -3,16 +3,13 @@ import torch
from packaging import version

try:
    pass

    from colossalai.kernel.triton.token_attention_kernel import token_attention_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) >= version.parse("11.6")


def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim):