Allow computation when bsz == 1

hxwang 2025-03-17 17:13:55 +08:00
parent 16e46efd79
commit b42472859a


@@ -800,7 +800,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+        if isinstance(out, tuple):
+            out = out[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
     @staticmethod
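The same unwrap guard is repeated three times in this commit (for out, x_grad, and w_grad): depending on the PyTorch build, torch._scaled_mm returns either the output tensor directly or a tuple whose first element is the output. A small compatibility helper, hypothetical and not part of the patch, could factor the pattern out:

import torch

def _scaled_mm_compat(*args, **kwargs):
    # Sketch of a version-compatibility wrapper around torch._scaled_mm.
    # Some PyTorch releases return a tuple (output first), newer ones return
    # the output tensor directly; normalize to a tensor either way.
    res = torch._scaled_mm(*args, **kwargs)
    return res[0] if isinstance(res, tuple) else res

# illustrative usage, mirroring the call site in the diff above
# (variable names taken from the diff, requires an fp8-capable GPU):
# out = _scaled_mm_compat(x_fp8, w_fp8_t, scale_a=inv_scale_x,
#                         scale_b=inv_scale_w, use_fast_accum=True)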
@@ -814,7 +818,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
@@ -822,7 +830,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
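For reference, the two _scaled_mm calls in this backward compute the standard linear-layer gradients. A plain-precision, CPU-runnable sketch of what the fp8 path approximates (not part of the patch):

import torch

def linear_backward_reference(out_grad, x, w, has_bias=True):
    # Shapes match the fp8 path: out_grad (m, n), x (m, k), w (n, k),
    # for a forward of out = x @ w.t().
    x_grad = out_grad @ w              # (m, n) @ (n, k) -> (m, k)
    w_grad = out_grad.t() @ x          # (n, m) @ (m, k) -> (n, k)
    bias_grad = out_grad.sum(0) if has_bias else None
    return x_grad, w_grad, bias_grad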
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
     """
 
     def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        if not (x.dim() == 2 and w.dim() == 2):
-            raise ValueError("Batched fp8 deep_gemm is not supported")
+        has_batch_dim = False
+        if x.dim() == 3:
+            has_batch_dim = True
+            if x.size(1) != 1:
+                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
+            x = x.squeeze(1)
+        ctx.has_batch_dim = has_batch_dim
         # x: (m, k), w: (n, k)
         # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
         (m, k), (n, _) = x.shape, w.shape
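The relaxed check above only admits 3-D activations whose middle dimension is 1, flattening them to the 2-D layout deep_gemm expects; any other 3-D shape still raises. A standalone sketch of that dimension handling (hypothetical helper, not part of the patch):

import torch

def squeeze_single_batch_dim(x: torch.Tensor):
    # Accept (m, k) unchanged; accept (m, 1, k) by squeezing dim 1;
    # reject any other 3-D shape, mirroring the check in the diff above.
    if x.dim() == 3:
        if x.size(1) != 1:
            raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
        return x.squeeze(1), True
    return x, False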
@@ -848,12 +866,17 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out
 
     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
         m, n, k = ctx.mnk
         o_per_tok = per_token_cast_to_fp8(o_grad)
@@ -864,6 +887,9 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
         return x_grad, w_grad
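Taken together, the forward and backward changes strip the size-1 batch dimension before the fp8 kernels run and restore it on the output and on x_grad (w_grad keeps its (n, k) shape). A shape-only sketch of that round trip, with torch.matmul standing in for the fp8 deep_gemm calls so it runs on CPU (not part of the patch):

import torch

def bsz1_round_trip(x, w):
    # forward: (m, 1, k) -> (m, k) @ (k, n) -> (m, n) -> (m, 1, n)
    x2d = x.squeeze(1)
    out = (x2d @ w.t()).unsqueeze(1)
    # backward: squeeze the incoming grad, compute 2-D grads,
    # then restore the batch dim on x_grad only.
    o_grad = torch.ones_like(out).squeeze(1)
    x_grad = (o_grad @ w).unsqueeze(1)
    w_grad = o_grad.t() @ x2d
    return out, x_grad, w_grad

out, x_grad, w_grad = bsz1_round_trip(torch.randn(8, 1, 64), torch.randn(32, 64))
assert out.shape == (8, 1, 32) and x_grad.shape == (8, 1, 64) and w_grad.shape == (32, 64)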