1
0
mirror of https://github.com/hpcaitech/ColossalAI.git synced 2025-08-26 19:50:53 +00:00
ColossalAI/tests/test_smoothquant/test_smoothquant_linear.py
Xu Kai 611a5a80ca
[inference] Add smoothquant for llama ()
* [inference] add int8 rotary embedding kernel for smoothquant ()

* [inference] add smoothquant llama attention ()

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  ()

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exception for tests

* [inference] add llama mlp for smoothquant ()

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama ()

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant ()

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes ()

* refactor code

* add license

* add torch-int and smoothquant license
2023-10-16 11:28:44 +08:00

40 lines
1.0 KiB
Python

import warnings
import pytest
import torch
# Probe for the smoothquant CUDA extension once at import time; the test below
# is skipped when it is unavailable.
try:
    from colossalai.kernel.op_builder.smoothquant import SmoothquantBuilder

    smoothquant_cuda = SmoothquantBuilder().load()
    HAS_SMOOTHQUANT_CUDA = True
except Exception as exc:
    # Best-effort probe: either the import of the builder or building/loading
    # the extension may fail. Catch Exception (not a bare except) so that
    # KeyboardInterrupt / SystemExit still propagate, and surface the reason.
    warnings.warn(f"CUDA smoothquant linear is not installed: {exc}")
    HAS_SMOOTHQUANT_CUDA = False
@pytest.mark.skipif(
    not HAS_SMOOTHQUANT_CUDA,
    reason="smoothquant linear not installed properly",
)
def test_linear():
    """Compare the fused int8 linear + SiLU CUDA kernel with a float32 reference.

    Inputs are random int8 activations ``a`` (128x512), int8 weights ``b``
    (512x256) and a float32 bias ``c`` (256,). The reference computes
    ``silu(alpha * (a @ b) + beta * c)`` in float32, and the kernel output
    must match it within rtol=atol=1e-2.
    """
    # torch.randint's upper bound is exclusive; 128 gives the symmetric
    # int8 range [-127, 127].
    a = torch.randint(-127, 128, (128, 512), dtype=torch.int8, device="cuda")
    b = torch.randint(-127, 128, (512, 256), dtype=torch.int8, device="cuda")
    c = torch.rand(256, dtype=torch.float, device="cuda")
    alpha = 1 / 127
    beta = 1.0

    # Float32 reference: silu(alpha * (a @ b) + beta * c).
    # NOTE(review): assumes the kernel scales the bias by beta (torch-int style
    # alpha/beta epilogue); with beta == 1.0 this equals the former `+ c`.
    torch_out = torch.mm(a.to(torch.float) * alpha, b.to(torch.float)) + beta * c
    torch_out = torch.nn.SiLU()(torch_out)

    # The kernel is given the weight transposed to (256, 512), i.e. (out, in);
    # presumably it computes a @ weight.T like nn.Linear — confirm in the kernel.
    weight = b.transpose(0, 1).contiguous()
    cuda_out = smoothquant_cuda.linear_silu_a8_w8_bfp32_ofp32(a, weight, c, alpha, beta)

    assert torch.allclose(torch_out, cuda_out, rtol=1e-02, atol=1e-02)
if __name__ == "__main__":
    # pytest's skipif marker is not applied when this file is executed as a
    # script, so honor the same condition here instead of failing with a
    # NameError on the missing `smoothquant_cuda` extension.
    if HAS_SMOOTHQUANT_CUDA:
        test_linear()
    else:
        warnings.warn("smoothquant linear not installed properly; skipping test_linear")