Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-27 19:36:13 +00:00)
[feature] support bias
commit 9ef62832fd (parent c832927d4e)
@@ -867,10 +867,13 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         return x_grad, w_grad


-def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
-    if bias is not None:
-        raise ValueError("bias is not supported in deep_gemm")
-    return _LinearFp8DeepGemm.apply(input, weight)
+def linear_fp8_deep_gemm(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    o = _LinearFp8DeepGemm.apply(input, weight)
+    if bias is not None:
+        o += bias
+    return o


 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
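The new implementation adds the bias outside the FP8 kernel, so `_LinearFp8DeepGemm.backward` still only returns `(x_grad, w_grad)` and autograd differentiates the bias addition on its own. A minimal sketch of that pattern, with a plain matmul standing in for the FP8 kernel (`_PlainLinear` and `plain_linear` are illustrative names, not ColossalAI code):

from typing import Optional

import torch


class _PlainLinear(torch.autograd.Function):
    # Stand-in for _LinearFp8DeepGemm: a bias-free matmul whose backward
    # only ever produces (x_grad, w_grad).
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(x, w)
        return x @ w.T  # (m, k) @ (k, n) -> (m, n)

    @staticmethod
    def backward(ctx, o_grad: torch.Tensor):
        x, w = ctx.saved_tensors
        return o_grad @ w, o_grad.T @ x  # x_grad: (m, k), w_grad: (n, k)


def plain_linear(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    o = _PlainLinear.apply(x, w)
    if bias is not None:
        o = o + bias  # ordinary torch op, so autograd derives bias.grad itself
    return o


x = torch.rand(4, 8, requires_grad=True)
w = torch.rand(16, 8, requires_grad=True)
b = torch.rand(16, requires_grad=True)
plain_linear(x, w, b).sum().backward()
# under sum(), each bias element gets one gradient contribution per row
assert torch.equal(b.grad, torch.full((16,), 4.0))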
@@ -16,15 +16,16 @@ def test_fp8_linear():
     # create tensors
     x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
     w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True)
     ref_w = w.clone().detach().requires_grad_()
     ref_x = x.clone().detach().requires_grad_()

-    out = linear_fp8_deep_gemm(x, w)
+    out = linear_fp8_deep_gemm(x, w, bias)
     assert out.shape == x.shape[:-1] + (n,)
     out.sum().backward()
-    ref_out = F.linear(ref_x, ref_w)
+    ref_out = F.linear(ref_x, ref_w, bias)
     ref_out.sum().backward()

-    assert_close(out, ref_out, rtol=0.2, atol=0.1)
-    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
-    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
+    assert_close(out, ref_out)
+    assert_close(x.grad, ref_x.grad)
+    assert_close(w.grad, ref_w.grad)
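The test itself won't run outside ColossalAI (`get_current_device`, `DTYPE`, and `m`/`n`/`k` live in that test file), but its structure — forward pass, `sum().backward()`, then `assert_close` on the output and the input/weight grads against an `F.linear` reference that now also receives the bias — is easy to mirror. A self-contained sketch with assumed sizes and an explicit `x @ w.T + bias` standing in for the FP8 path:

import torch
import torch.nn.functional as F
from torch.testing import assert_close

m, k, n = 128, 512, 256  # assumed sizes; the real test defines its own

# create tensors
x = torch.rand((m, k), dtype=torch.float32, requires_grad=True)
w = torch.rand((n, k), dtype=torch.float32, requires_grad=True)
bias = torch.rand(n, dtype=torch.float32, requires_grad=True)
ref_w = w.clone().detach().requires_grad_()
ref_x = x.clone().detach().requires_grad_()

out = x @ w.T + bias  # stand-in for linear_fp8_deep_gemm(x, w, bias)
assert out.shape == x.shape[:-1] + (n,)
out.sum().backward()
ref_out = F.linear(ref_x, ref_w, bias)  # reference path also takes the bias
ref_out.sum().backward()

# fp32 matmuls agree within default tolerances; a real FP8 path would
# need loose rtol/atol, as in the original version of this test
assert_close(out, ref_out)
assert_close(x.grad, ref_x.grad)
assert_close(w.grad, ref_w.grad)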