ColossalAI/tests/test_fp8/test_fp8_deepgemm.py
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8_deep_gemm
from colossalai.utils import get_current_device

m, k, n = 128, 384, 256
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_linear():
    # create input, weight, and bias tensors on the current device
    x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
    bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    ref_x = x.clone().detach().requires_grad_()

    # forward and backward through the FP8 DeepGEMM linear kernel
    out = linear_fp8_deep_gemm(x, w, bias)
    assert out.shape == x.shape[:-1] + (n,)
    out.sum().backward()

    # reference forward and backward in bfloat16 via F.linear
    ref_out = F.linear(ref_x, ref_w, bias)
    ref_out.sum().backward()

    # outputs and gradients should match the bfloat16 reference within tolerance
    assert_close(out, ref_out)
    assert_close(x.grad, ref_x.grad)
    assert_close(w.grad, ref_w.grad)
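

# A minimal direct-run entry point, sketched here as an assumption following the common
# ColossalAI test convention; it is not part of the original file.
if __name__ == "__main__":
    test_fp8_linear()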