mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[inference] Add smmoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843) * [inference] add smoothquant llama attention (#4850) * add smoothquant llama attention * remove uselss code * remove useless code * fix import error * rename file name * [inference] add silu linear fusion for smoothquant llama mlp (#4853) * add silu linear * update skip condition * catch smoothquant cuda lib exception * prcocess exception for tests * [inference] add llama mlp for smoothquant (#4854) * add llama mlp for smoothquant * fix down out scale * remove duplicate lines * add llama mlp check * delete useless code * [inference] add smoothquant llama (#4861) * add smoothquant llama * fix attention accuracy * fix accuracy * add kv cache and save pretrained * refactor example * delete smooth * refactor code * [inference] add smooth function and delete useless code for smoothquant (#4895) * add smooth function and delete useless code * update datasets * remove duplicate import * delete useless file * refactor codes (#4902) * rafactor code * add license * add torch-int and smoothquant license
This commit is contained in:
136
tests/test_smoothquant/test_llama_attention.py
Normal file
136
tests/test_smoothquant/test_llama_attention.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import int8_rotary_embedding_fwd
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
try:
|
||||
from colossalai.inference.quant.smoothquant.models import LLamaSmoothquantAttention
|
||||
|
||||
HAS_TORCH_INT = True
|
||||
except ImportError:
|
||||
HAS_TORCH_INT = False
|
||||
print("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
|
||||
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim):
|
||||
"""
|
||||
adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253
|
||||
"""
|
||||
xq = xq.view(bs, seqlen, num_head, head_dim)
|
||||
xk = xk.view(bs, seqlen, num_head, head_dim)
|
||||
xv = xv.view(bs, seqlen, num_head, head_dim)
|
||||
mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda()
|
||||
mask[mask == 0.0] = -100000000.0
|
||||
mask = mask.repeat(bs, num_head, 1, 1)
|
||||
keys = xk
|
||||
values = xv
|
||||
xq = xq.transpose(1, 2)
|
||||
keys = keys.transpose(1, 2)
|
||||
values = values.transpose(1, 2)
|
||||
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
|
||||
scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq)
|
||||
output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON or not HAS_TORCH_INT,
|
||||
reason="triton requires cuda version to be higher than 11.4 or not install torch_int",
|
||||
)
|
||||
def test_llama_context_attention():
|
||||
head_num = 2
|
||||
seq_len = 32
|
||||
head_dim = 64
|
||||
dtype = torch.float
|
||||
hidden_size = head_num * head_dim
|
||||
|
||||
smooth_attn = LLamaSmoothquantAttention(head_num * head_dim, head_num)
|
||||
|
||||
smooth_attn.q_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.k_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.v_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.out_proj.weight = torch.ones(hidden_size, hidden_size, device="cuda").to(torch.int8)
|
||||
smooth_attn.out_proj.weight[:, 1:hidden_size] = torch.zeros(hidden_size - 1, device="cuda").to(torch.int8)
|
||||
|
||||
qkv_weight_scale = 1.0
|
||||
|
||||
ones = torch.ones(hidden_size, hidden_size, dtype=torch.float, device="cuda")
|
||||
|
||||
smooth_attn = smooth_attn.to("cuda")
|
||||
|
||||
input = torch.randint(-20, 20, (1, seq_len, head_num * head_dim), dtype=torch.int8, device="cuda")
|
||||
input_scale = 1 / 20.0
|
||||
|
||||
output = torch.matmul(input.to(torch.float) * input_scale, ones)
|
||||
qkv_max_out = torch.max(torch.abs(output)) / 127
|
||||
smooth_attn.q_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
smooth_attn.k_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
smooth_attn.v_proj.a = torch.tensor(input_scale * qkv_weight_scale / qkv_max_out)
|
||||
|
||||
q = smooth_attn.q_proj(input)
|
||||
k = smooth_attn.k_proj(input)
|
||||
v = smooth_attn.v_proj(input)
|
||||
|
||||
cos_shape = (seq_len, head_dim // 2)
|
||||
cos = torch.ones(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = torch.zeros(cos_shape, dtype=dtype, device="cuda")
|
||||
in_scale = torch.tensor([qkv_max_out], device="cuda")
|
||||
out_scale = torch.tensor([qkv_max_out], device="cuda")
|
||||
int8_rotary_embedding_fwd(q.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
|
||||
int8_rotary_embedding_fwd(k.view(-1, head_num, head_dim), cos, sin, in_scale.item(), out_scale.item())
|
||||
|
||||
q = q.to(torch.float) * out_scale
|
||||
k = k.to(torch.float) * out_scale
|
||||
v = v.to(torch.float) * out_scale
|
||||
torch_out = torch_context_attention(q.clone(), k.clone(), v.clone(), 1, seq_len, head_num, head_dim)
|
||||
attn_out_max = torch.max(torch.abs(torch_out)) / 127
|
||||
|
||||
output = torch.matmul(torch_out.view(-1, seq_len, head_num * head_dim), ones)
|
||||
smooth_attn.q_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.k_output_scale = torch.tensor(qkv_max_out)
|
||||
|
||||
smooth_attn.v_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.q_rotary_output_scale = torch.tensor(qkv_max_out)
|
||||
smooth_attn.k_rotary_output_scale = torch.tensor(qkv_max_out)
|
||||
|
||||
smooth_attn.attn_output_scale = torch.tensor(attn_out_max)
|
||||
smooth_attn.out_proj.a = torch.tensor([attn_out_max])
|
||||
|
||||
torch_out = (
|
||||
(torch_out / smooth_attn.attn_output_scale)
|
||||
.round()
|
||||
.clamp(-128, 127)
|
||||
.to(torch.int8)
|
||||
.view(-1, seq_len, head_num * head_dim)
|
||||
)
|
||||
|
||||
torch_out = smooth_attn.out_proj(torch_out)
|
||||
torch_out = torch_out.to(torch.float)
|
||||
|
||||
smooth_attn = smooth_attn.to("cuda")
|
||||
smooth_out, _, _ = smooth_attn(input, (cos, sin))
|
||||
smooth_out = smooth_out.to(torch.float)
|
||||
|
||||
assert torch.allclose(
|
||||
torch_out.cpu(), smooth_out.cpu(), rtol=1e-1, atol=1e-1
|
||||
), "outputs from triton and torch are not matched"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama_context_attention()
|
84
tests/test_smoothquant/test_llama_mlp.py
Normal file
84
tests/test_smoothquant/test_llama_mlp.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except:
|
||||
warnings.warn("CUDA smoothquant linear is not installed")
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
|
||||
|
||||
try:
|
||||
from colossalai.inference.quant.smoothquant.models import LlamaSmoothquantMLP
|
||||
|
||||
HAS_TORCH_INT = True
|
||||
except:
|
||||
HAS_TORCH_INT = False
|
||||
warnings.warn("Please install torch_int from https://github.com/Guangxuan-Xiao/torch-int")
|
||||
|
||||
|
||||
CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_llama_mlp(gate_proj, up_proj, down_proj, x):
|
||||
gate_out = torch.mm(x, gate_proj)
|
||||
silu = torch.nn.SiLU()
|
||||
gate_out = silu(gate_out)
|
||||
up_out = torch.mm(x, up_proj)
|
||||
|
||||
o_out = gate_out * up_out
|
||||
|
||||
max_up = torch.max(torch.abs(o_out))
|
||||
min_up = torch.min(torch.abs(o_out))
|
||||
|
||||
torch_out = torch.mm(o_out, down_proj)
|
||||
|
||||
return (torch_out, max_up, min_up)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not CUDA_SUPPORT or not HAS_SMOOTHQUANT_CUDA or not HAS_TORCH_INT,
|
||||
reason="smoothquant linear not installed properly or not install torch_int",
|
||||
)
|
||||
def test_llama_mlp():
|
||||
hidden_size = 256
|
||||
intermediate_size = 512
|
||||
|
||||
smooth_mlp = LlamaSmoothquantMLP(intermediate_size, hidden_size)
|
||||
|
||||
smooth_mlp.gate_proj.weight = torch.ones((intermediate_size, hidden_size), dtype=torch.int8, device="cuda")
|
||||
|
||||
smooth_mlp.up_proj.weight = torch.randint(
|
||||
-10, 10, (intermediate_size, hidden_size), dtype=torch.int8, device="cuda"
|
||||
)
|
||||
smooth_mlp.down_proj.weight = torch.randint(
|
||||
-10, 10, (hidden_size, intermediate_size), dtype=torch.int8, device="cuda"
|
||||
)
|
||||
|
||||
x = torch.ones((1, 256), dtype=torch.int8, device="cuda")
|
||||
|
||||
torch_out, max_inter, min_inter = torch_llama_mlp(
|
||||
smooth_mlp.gate_proj.weight.transpose(0, 1).to(torch.float) / hidden_size,
|
||||
smooth_mlp.up_proj.weight.transpose(0, 1).to(torch.float) / 127,
|
||||
smooth_mlp.down_proj.weight.transpose(0, 1).to(torch.float) / 127,
|
||||
x.to(torch.float),
|
||||
)
|
||||
|
||||
smooth_mlp.down_proj_input_scale = torch.tensor(max_inter.item() / 127)
|
||||
smooth_mlp.gate_proj.a = torch.tensor(1 / hidden_size)
|
||||
smooth_mlp.up_proj.a = torch.tensor(1 / 127)
|
||||
smooth_mlp.down_proj.a = torch.tensor(1 / 127 * (max_inter.item() / 127))
|
||||
|
||||
smooth_out = smooth_mlp(x)
|
||||
|
||||
assert torch.allclose(torch_out, smooth_out, rtol=1e-02, atol=1e-01)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_llama_mlp()
|
39
tests/test_smoothquant/test_smoothquant_linear.py
Normal file
39
tests/test_smoothquant/test_smoothquant_linear.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
try:
|
||||
from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder
|
||||
|
||||
smoothquant_cuda = SmoothquantBuilder().load()
|
||||
HAS_SMOOTHQUANT_CUDA = True
|
||||
except:
|
||||
warnings.warn("CUDA smoothquant linear is not installed")
|
||||
HAS_SMOOTHQUANT_CUDA = False
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not HAS_SMOOTHQUANT_CUDA,
|
||||
reason="smoothquant linear not installed properly",
|
||||
)
|
||||
def test_linear():
|
||||
a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
|
||||
b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
|
||||
c = torch.rand(256, dtype=torch.float, device="cuda")
|
||||
|
||||
alpha = 1 / 127
|
||||
beta = 1.0
|
||||
torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c
|
||||
|
||||
silu = torch.nn.SiLU()
|
||||
torch_out = silu(torch_out)
|
||||
|
||||
b = b.transpose(0, 1).contiguous()
|
||||
cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)
|
||||
|
||||
assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_linear()
|
59
tests/test_smoothquant/test_sq_rotary_embedding.py
Normal file
59
tests/test_smoothquant/test_sq_rotary_embedding.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Adapted from ModelTC https://github.com/ModelTC/lightllm
|
||||
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai.kernel.triton import int8_rotary_embedding_fwd
|
||||
|
||||
HAS_TRITON = True
|
||||
except ImportError:
|
||||
HAS_TRITON = False
|
||||
print("please install triton from https://github.com/openai/triton")
|
||||
|
||||
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")
|
||||
|
||||
|
||||
def torch_rotary_emb(x, cos, sin):
|
||||
seq_len, h, dim = x.shape
|
||||
x0 = x[:, :, 0 : dim // 2]
|
||||
x1 = x[:, :, dim // 2 : dim]
|
||||
cos = cos.view((seq_len, 1, dim // 2))
|
||||
sin = sin.view((seq_len, 1, dim // 2))
|
||||
o0 = x0 * cos - x1 * sin
|
||||
o1 = x0 * sin + x1 * cos
|
||||
return torch.cat((o0, o1), dim=-1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
|
||||
)
|
||||
def test_rotary_emb():
|
||||
SEQ_LEN = 1
|
||||
HEAD_NUM = 32
|
||||
HEAD_DIM = 128
|
||||
dtype = torch.float
|
||||
# create data
|
||||
x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
|
||||
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
|
||||
cos_shape = (SEQ_LEN, HEAD_DIM // 2)
|
||||
cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
|
||||
# forward pass
|
||||
y_torch = torch_rotary_emb(x, cos, sin)
|
||||
|
||||
input_scale = torch.max(torch.abs(x)) / 127
|
||||
output_scale = torch.max(torch.abs(y_torch)) / 127
|
||||
|
||||
x = x / input_scale
|
||||
x = x.to(torch.int8)
|
||||
|
||||
int8_rotary_embedding_fwd(x, cos, sin, input_scale.item(), output_scale.item())
|
||||
y_triton = x.to(torch.float) * output_scale
|
||||
assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_rotary_emb()
|
Reference in New Issue
Block a user