diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index d899964e5..bd0358cbd 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -867,10 +867,13 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         return x_grad, w_grad
 
 
-def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
+def linear_fp8_deep_gemm(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    o = _LinearFp8DeepGemm.apply(input, weight)
     if bias is not None:
-        raise ValueError("bias is not supported in deep_gemm")
-    return _LinearFp8DeepGemm.apply(input, weight)
+        o += bias
+    return o
 
 
 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
diff --git a/tests/test_fp8/test_fp8_deepgemm.py b/tests/test_fp8/test_fp8_deepgemm.py
index 47aed247a..5f1de4f9c 100644
--- a/tests/test_fp8/test_fp8_deepgemm.py
+++ b/tests/test_fp8/test_fp8_deepgemm.py
@@ -16,15 +16,16 @@ def test_fp8_linear():
     # create tensors
     x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
     w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True)
     ref_w = w.clone().detach().requires_grad_()
     ref_x = x.clone().detach().requires_grad_()
-    out = linear_fp8_deep_gemm(x, w)
+    out = linear_fp8_deep_gemm(x, w, bias)
     assert out.shape == x.shape[:-1] + (n,)
     out.sum().backward()
-    ref_out = F.linear(ref_x, ref_w)
+    ref_out = F.linear(ref_x, ref_w, bias)
     ref_out.sum().backward()
-    assert_close(out, ref_out, rtol=0.2, atol=0.1)
-    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
-    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
+    assert_close(out, ref_out)
+    assert_close(x.grad, ref_x.grad)
+    assert_close(w.grad, ref_w.grad)
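
For context, a minimal usage sketch of the updated entry point (not part of the patch; the sizes, dtype, and CUDA device below are illustrative assumptions, not values taken from the diff):

    import torch

    from colossalai.quantization.fp8 import linear_fp8_deep_gemm

    # Illustrative shapes/dtype: an (m, k) input paired with an (n, k) weight.
    m, n, k = 128, 256, 512
    x = torch.rand((m, k), device="cuda", dtype=torch.bfloat16, requires_grad=True)
    w = torch.rand((n, k), device="cuda", dtype=torch.bfloat16, requires_grad=True)
    b = torch.rand(n, device="cuda", dtype=torch.bfloat16, requires_grad=True)

    # Passing a bias previously raised ValueError; after this change it is
    # simply added to the FP8 GEMM output.
    out = linear_fp8_deep_gemm(x, w, b)  # shape (m, n)
    out.sum().backward()                 # gradients flow back through the GEMM and the bias add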