Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 11:32:10 +00:00)
[inference] Add smoothquant for llama (#4904)
* [inference] add int8 rotary embedding kernel for smoothquant (#4843)
* [inference] add smoothquant llama attention (#4850)
* add smoothquant llama attention
* remove useless code
* remove useless code
* fix import error
* rename file name
* [inference] add silu linear fusion for smoothquant llama mlp (#4853)
* add silu linear
* update skip condition
* catch smoothquant cuda lib exception
* process exception for tests
* [inference] add llama mlp for smoothquant (#4854)
* add llama mlp for smoothquant
* fix down out scale
* remove duplicate lines
* add llama mlp check
* delete useless code
* [inference] add smoothquant llama (#4861)
* add smoothquant llama
* fix attention accuracy
* fix accuracy
* add kv cache and save pretrained
* refactor example
* delete smooth
* refactor code
* [inference] add smooth function and delete useless code for smoothquant (#4895)
* add smooth function and delete useless code
* update datasets
* remove duplicate import
* delete useless file
* refactor codes (#4902)
* refactor code
* add license
* add torch-int and smoothquant license
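For context on the "smooth function" mentioned in the list above: SmoothQuant migrates activation outliers into the weights via a per-channel scale before int8 quantization. The sketch below is only an illustrative outline of that idea, not code from this PR; the function name smooth_linear, the tensor shapes, and the default alpha=0.5 are assumptions.

import torch

def smooth_linear(act_max: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5):
    # Illustrative SmoothQuant-style smoothing, not this PR's implementation.
    # act_max: per-channel max of |activation|, shape [in_features]
    # weight:  linear weight, shape [out_features, in_features]
    weight_max = weight.abs().amax(dim=0).clamp(min=1e-5)
    scales = act_max.clamp(min=1e-5).pow(alpha) / weight_max.pow(1.0 - alpha)
    # Activations are divided by `scales` (usually folded into the preceding
    # LayerNorm) while the weight absorbs them, so x @ W.T is unchanged but the
    # activation range becomes much easier to quantize to int8.
    return scales, weight * scales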
tests/test_smoothquant/test_smoothquant_linear.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import warnings

import pytest
import torch

try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except:
    warnings.warn("CUDA smoothquant linear is not installed")
    HAS_SMOOTHQUANT_CUDA = False


@pytest.mark.skipif(
    not HAS_SMOOTHQUANT_CUDA,
    reason="smoothquant linear not installed properly",
)
def test_linear():
    a = torch.randint(-127, 127, (128, 512), dtype=torch.int8, device="cuda")
    b = torch.randint(-127, 127, (512, 256), dtype=torch.int8, device="cuda")
    c = torch.rand(256, dtype=torch.float, device="cuda")

    alpha = 1 / 127
    beta = 1.0
    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + c

    silu = torch.nn.SiLU()
    torch_out = silu(torch_out)

    b = b.transpose(0, 1).contiguous()
    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, b, c, alpha, beta)

    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)


if __name__ == "__main__":
    test_linear()
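From the reference computation in the test, linear_silu_a8_w8_bfp32_ofp32 appears to evaluate SiLU(alpha * (A @ B) + beta * bias) for int8 A and B with an fp32 bias and fp32 output, with the weight passed in [out_features, in_features] layout after the transpose. The snippet below is a hedged illustration of how such a per-tensor dequantization factor alpha is commonly derived for symmetric int8 quantization; the helper int8_scale and the example tensors are not part of the PR or the kernel API.

import torch

def int8_scale(x: torch.Tensor) -> float:
    # Symmetric per-tensor scale so that real_value is approximately scale * int8_value.
    return x.abs().max().item() / 127.0

# Hypothetical fp32 tensors that the random int8 tensors in the test stand in for.
x_fp = torch.randn(128, 512)
w_fp = torch.randn(512, 256)

scale_x, scale_w = int8_scale(x_fp), int8_scale(w_fp)
x_q = (x_fp / scale_x).round().clamp(-127, 127).to(torch.int8)
w_q = (w_fp / scale_w).round().clamp(-127, 127).to(torch.int8)

# The product of the two scales plays the role of `alpha` in the test:
# SiLU(alpha * (x_q @ w_q) + bias) approximates SiLU(x_fp @ w_fp + bias).
alpha = scale_x * scale_w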