mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 12:43:55 +00:00
32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
import pytest
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch.testing import assert_close
|
|
|
|
from colossalai.accelerator import get_accelerator
|
|
from colossalai.quantization.fp8 import linear_fp8_deep_gemm
|
|
from colossalai.utils import get_current_device
|
|
|
|
# Problem size used by the test below: the input is created with shape (m, k),
# the weight with shape (n, k), and the bias with n elements.
m, k, n = 128, 384, 256
# Dtype passed to every tensor the test creates.
DTYPE = torch.bfloat16
|
|
|
|
|
|
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_linear():
    """Compare ``linear_fp8_deep_gemm`` with ``F.linear`` on random inputs.

    Checks the output shape, the forward result, and the gradients with
    respect to the input and the weight. The test is skipped when the major
    device capability reported by the accelerator is below 9.
    """
    device = get_current_device()

    # create tensors
    x = torch.rand((m, k), device=device, dtype=DTYPE, requires_grad=True)
    w = torch.rand((n, k), device=device, dtype=DTYPE, requires_grad=True)
    bias = torch.rand(n, device=device, dtype=DTYPE, requires_grad=True)

    # Detached leaf copies so the reference path accumulates its own
    # gradients instead of sharing ``.grad`` with ``x`` and ``w``.
    ref_w = w.clone().detach().requires_grad_()
    ref_x = x.clone().detach().requires_grad_()

    # Path under test.
    out = linear_fp8_deep_gemm(x, w, bias)
    assert out.shape == x.shape[:-1] + (n,)
    out.sum().backward()

    # Reference path. ``bias`` is the same tensor in both paths; its gradient
    # is not compared below.
    ref_out = F.linear(ref_x, ref_w, bias)
    ref_out.sum().backward()

    # Default ``assert_close`` tolerances for the tensors' dtype are used.
    assert_close(out, ref_out)
    assert_close(x.grad, ref_x.grad)
    assert_close(w.grad, ref_w.grad)
|