[feature] support bias

This commit is contained in:
hxwang 2025-03-13 17:48:13 +08:00
parent c832927d4e
commit 9ef62832fd
No known key found for this signature in database
GPG Key ID: 0EC383D418F0B9F8
2 changed files with 12 additions and 8 deletions

View File

@ -867,10 +867,13 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
return x_grad, w_grad
def linear_fp8_deep_gemm(
    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """Apply the FP8 DeepGEMM linear function, optionally adding a bias.

    Args:
        input: Tensor forwarded as the first argument of
            ``_LinearFp8DeepGemm.apply``.
        weight: Tensor forwarded as the second argument of
            ``_LinearFp8DeepGemm.apply``.
        bias: Optional tensor added to the GEMM result. It must broadcast
            against that result. ``None`` (the default) skips the addition.

    Returns:
        The result of ``_LinearFp8DeepGemm.apply(input, weight)``, with
        ``bias`` added when one is given.
    """
    out = _LinearFp8DeepGemm.apply(input, weight)
    if bias is None:
        return out
    # Add out-of-place rather than with ``+=``: mutating the output of a custom
    # autograd.Function in place is rejected by autograd when that output is a
    # view, and it needlessly couples this wrapper to the Function's internals.
    # NOTE(review): whether the forward returns a view is not visible here.
    # Casting the bias keeps the result in the GEMM output's dtype, which is
    # what the in-place add produced.
    return out + bias.to(out.dtype)
@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)

View File

@ -16,15 +16,16 @@ def test_fp8_linear():
# create tensors
x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True)
ref_w = w.clone().detach().requires_grad_()
ref_x = x.clone().detach().requires_grad_()
out = linear_fp8_deep_gemm(x, w)
out = linear_fp8_deep_gemm(x, w, bias)
assert out.shape == x.shape[:-1] + (n,)
out.sum().backward()
ref_out = F.linear(ref_x, ref_w)
ref_out = F.linear(ref_x, ref_w, bias)
ref_out.sum().backward()
assert_close(out, ref_out, rtol=0.2, atol=0.1)
assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
assert_close(out, ref_out)
assert_close(x.grad, ref_x.grad)
assert_close(w.grad, ref_w.grad)