mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-01 21:26:42 +00:00
[feature] support bias
This commit is contained in:
parent c832927d4e
commit 9ef62832fd
@@ -867,10 +867,13 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         return x_grad, w_grad
 
 
-def linear_fp8_deep_gemm(input: torch.Tensor, weight: torch.Tensor, bias: None = None) -> torch.Tensor:
+def linear_fp8_deep_gemm(
+    input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    o = _LinearFp8DeepGemm.apply(input, weight)
     if bias is not None:
-        raise ValueError("bias is not supported in deep_gemm")
-    return _LinearFp8DeepGemm.apply(input, weight)
+        o += bias
+    return o
 
 
 @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
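The change keeps the bias add outside the custom autograd.Function: _LinearFp8DeepGemm.backward still returns only x_grad and w_grad, and autograd differentiates the broadcast o += bias on its own, so no bias handling is needed in the Function's backward. A minimal sketch of the same pattern, with a plain matmul standing in for the fp8 DeepGEMM kernel (the names _MatmulOnly and linear_like are illustrative, not from the repo):

    # Sketch: bias stays outside the autograd.Function, so its gradient
    # comes from autograd's handling of the broadcast add, not from
    # backward(). Plain matmul stands in for the fp8 DeepGEMM kernel.
    import torch

    class _MatmulOnly(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, w):
            ctx.save_for_backward(x, w)
            return x @ w.t()                     # (m, k) @ (k, n) -> (m, n)

        @staticmethod
        def backward(ctx, o_grad):
            x, w = ctx.saved_tensors
            return o_grad @ w, o_grad.t() @ x    # x_grad, w_grad only

    def linear_like(x, w, bias=None):
        o = _MatmulOnly.apply(x, w)
        if bias is not None:
            o = o + bias                         # autograd differentiates this add
        return o

The diff's in-place o += bias is safe as long as the Function does not save its output for backward; the out-of-place form above sidesteps the question entirely.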
@@ -16,15 +16,16 @@ def test_fp8_linear():
     # create tensors
     x = torch.rand((m, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
     w = torch.rand((n, k), device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    bias = torch.rand(n, device=get_current_device(), dtype=DTYPE, requires_grad=True)
     ref_w = w.clone().detach().requires_grad_()
     ref_x = x.clone().detach().requires_grad_()
 
-    out = linear_fp8_deep_gemm(x, w)
+    out = linear_fp8_deep_gemm(x, w, bias)
     assert out.shape == x.shape[:-1] + (n,)
     out.sum().backward()
-    ref_out = F.linear(ref_x, ref_w)
+    ref_out = F.linear(ref_x, ref_w, bias)
     ref_out.sum().backward()
 
-    assert_close(out, ref_out, rtol=0.2, atol=0.1)
-    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
-    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
+    assert_close(out, ref_out)
+    assert_close(x.grad, ref_x.grad)
+    assert_close(w.grad, ref_w.grad)
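Note that the test exercises the bias only through out and the x/w gradients; bias.grad itself is never compared. With loss = out.sum() it has a simple closed form: each bias element receives one unit of gradient per output row, i.e. m in total. A hypothetical extra check along those lines (not part of the commit):

    # Hypothetical check, not in the commit: with loss = out.sum(), every
    # bias element's gradient is exactly m (one contribution per row).
    import torch

    m, n = 128, 64
    bias = torch.rand(n, requires_grad=True)
    out = torch.rand(m, n) + bias        # stand-in for linear_fp8_deep_gemm
    out.sum().backward()
    assert torch.allclose(bias.grad, torch.full((n,), float(m)))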