diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index bd0358cbd..4b7c48354 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -800,7 +800,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
     @staticmethod
@@ -814,7 +818,11 @@
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
+
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
@@ -822,7 +830,11 @@
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
     """
 
     def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        if not (x.dim() == 2 and w.dim() == 2):
-            raise ValueError("Batched fp8 deep_gemm is not supported")
+        has_batch_dim = False
+        if x.dim() == 3:
+            has_batch_dim = True
+            if x.size(1) != 1:
+                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
+            x = x.squeeze(1)
+        ctx.has_batch_dim = has_batch_dim
+
         # x: (m, k), w: (n, k)
         # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
         (m, k), (n, _) = x.shape, w.shape
@@ -848,12 +866,17 @@
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out
 
     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
+
         m, n, k = ctx.mnk
 
         o_per_tok = per_token_cast_to_fp8(o_grad)
@@ -864,6 +887,9 @@
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)
 
+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
+
         return x_grad, w_grad
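
Note on the `)[0]` → tuple-check changes in `_LinearFp8`: older PyTorch releases returned an `(output, amax)` tuple from `torch._scaled_mm`, while newer releases return the output tensor directly, so unconditional `[0]` indexing breaks on recent versions. A minimal sketch of the normalization pattern the patch applies at each call site (the helper name `_unwrap_scaled_mm` is illustrative, not part of the patch):

```python
import torch


def _unwrap_scaled_mm(result):
    # Older PyTorch releases returned (output, amax) from torch._scaled_mm;
    # newer releases return the output tensor directly. Accept both forms.
    if isinstance(result, tuple):
        result = result[0]
    return result
```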
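
The `_LinearFp8DeepGemm` change relaxes the strict 2-D requirement to accept a singleton batch dimension: a `(m, 1, k)` input is squeezed to `(m, k)` in `forward`, and the dimension is restored on the output and on `x_grad` in `backward`; a 3-D input with a non-singleton middle dimension still raises. A hedged usage sketch, assuming a CUDA device with the deep_gemm backend available (the shapes below are illustrative only):

```python
import torch

# Illustrative shapes: x is (m, 1, k), w is (n, k).
x = torch.randn(16, 1, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)

out = _LinearFp8DeepGemm.apply(x, w)  # (16, 1, 128): batch dim restored by unsqueeze(1)
out.sum().backward()                  # x.grad: (16, 1, 256), w.grad: (128, 256)
```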