Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-27 11:31:58 +00:00)
Allow computation when bsz == 1
This commit is contained in:
parent 16e46efd79
commit b42472859a
@@ -800,7 +800,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=inv_scale_x,
             scale_b=inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(out, tuple):
+            out = out[0]
+
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])

     @staticmethod
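Dropping the trailing [0] and adding the isinstance check lets the same call work whether torch._scaled_mm returns a bare tensor or a (tensor, ...) tuple, which differs across PyTorch versions. A minimal sketch of that unwrap pattern (the helper name is illustrative, not part of the patch):

import torch

def _unwrap(result):
    # torch._scaled_mm may hand back a bare tensor or a (tensor, ...) tuple.
    return result[0] if isinstance(result, tuple) else result

print(_unwrap(torch.ones(2)))              # bare tensor passes through unchanged
print(_unwrap((torch.ones(2), None)))      # tuple is unwrapped to its first element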
@@ -814,7 +818,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(x_grad, tuple):
+            x_grad = x_grad[0]
+
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
             ctx.x_fp8.t().contiguous().t(),
@@ -822,7 +830,11 @@ class _LinearFp8(torch.autograd.Function):
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
             use_fast_accum=True,
-        )[0]
+        )
+
+        if isinstance(w_grad, tuple):
+            w_grad = w_grad[0]
+
         bias_grad = None
         if ctx.has_bias:
             bias_grad = out_grad.sum(0)
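The two _scaled_mm calls and the sum above compute, in fp8, the standard backward of F.linear. A plain-precision sketch of the same gradient shapes (assumed shapes x: (m, k), w: (n, k), bias: (n,); not the patched code):

import torch

m, n, k = 4, 8, 16
x, w, bias = torch.randn(m, k), torch.randn(n, k), torch.randn(n)
out_grad = torch.randn(m, n)        # gradient w.r.t. out = x @ w.t() + bias

x_grad = out_grad @ w               # (m, n) @ (n, k) -> (m, k)
w_grad = out_grad.t() @ x           # (n, m) @ (m, k) -> (n, k)
bias_grad = out_grad.sum(0)         # (n,)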
@@ -835,8 +847,14 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
     """

     def forward(ctx: Any, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
-        if not (x.dim() == 2 and w.dim() == 2):
-            raise ValueError("Batched fp8 deep_gemm is not supported")
+        has_batch_dim = False
+        if x.dim() == 3:
+            has_batch_dim = True
+            if x.size(1) != 1:
+                raise ValueError(f"Batched fp8 deep_gemm is not supported, found x shape: {x.shape}")
+            x = x.squeeze(1)
+        ctx.has_batch_dim = has_batch_dim
+
         # x: (m, k), w: (n, k)
         # x @ w_t -> (m, k) @ (k, n) -> deep_gemm((m, k), (n, k))
         (m, k), (n, _) = x.shape, w.shape
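Instead of rejecting every non-2-D input, the forward now squeezes a size-1 middle dimension (the bsz == 1 case) so the 2-D deep_gemm path can run, and records that in ctx.has_batch_dim; the output is unsqueezed back in the next hunk. A shape-only sketch with plain matmul standing in for the fp8 kernel:

import torch

x = torch.randn(5, 1, 16)      # (m, 1, k): 3-D input whose middle dim is 1
w = torch.randn(8, 16)         # (n, k)

has_batch_dim = x.dim() == 3
if has_batch_dim:
    assert x.size(1) == 1      # any other 3-D shape is still rejected by the patch
    x = x.squeeze(1)           # -> (m, k)

out = x @ w.t()                # (m, k) @ (k, n) -> (m, n)
if has_batch_dim:
    out = out.unsqueeze(1)     # -> (m, 1, n), restoring the input layout

print(out.shape)               # torch.Size([5, 1, 8])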
@@ -848,12 +866,17 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         ctx.w_t_per_plk = per_block_cast_to_fp8(w.t())
         ctx.x_t_per_blk = per_block_cast_to_fp8(x.t())
         ctx.mnk = (m, n, k)
+        if has_batch_dim:
+            out = out.unsqueeze(1)
         return out

     def backward(ctx: Any, o_grad: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         # o_grad: (m, n)
         # x_grad: (m, k) -> (m, n) @ (n, k) -> deep_gemm((m, n), (k, n))
         # w_grad: (n, k) -> (m, n).t() @ (m, k) -> deep_gemm((m, n).t(), (k, m))
+        if ctx.has_batch_dim:
+            o_grad = o_grad.squeeze(1)
+
         m, n, k = ctx.mnk
         o_per_tok = per_token_cast_to_fp8(o_grad)
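backward mirrors the bookkeeping: o_grad arrives in the 3-D layout of the forward output, is squeezed before the 2-D gemms, and x_grad is unsqueezed before being returned, while w_grad stays 2-D. Continuing the sketch above (plain matmul in place of deep_gemm):

import torch

m, n, k = 5, 8, 16
x, w = torch.randn(m, k), torch.randn(n, k)   # the already-squeezed forward operands
o_grad = torch.randn(m, 1, n)                 # gradient in the (m, 1, n) output layout

o_grad2d = o_grad.squeeze(1)                  # (m, n)
x_grad = (o_grad2d @ w).unsqueeze(1)          # (m, n) @ (n, k) -> (m, k) -> (m, 1, k)
w_grad = o_grad2d.t() @ x                     # (n, m) @ (m, k) -> (n, k)
print(x_grad.shape, w_grad.shape)             # torch.Size([5, 1, 16]) torch.Size([8, 16])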
@@ -864,6 +887,9 @@ class _LinearFp8DeepGemm(torch.autograd.Function):
         w_grad = torch.empty((n, k), dtype=torch.bfloat16, device=o_grad.device)
         deepgemm_fp8_gemm(o_grad_t_per_tok, ctx.x_t_per_blk, w_grad)

+        if ctx.has_batch_dim:
+            x_grad = x_grad.unsqueeze(1)
+
         return x_grad, w_grad
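An end-to-end shape check of the bsz == 1 path, with a plain-precision autograd.Function standing in for _LinearFp8DeepGemm (the real class needs fp8-capable hardware); this is a sketch of the bookkeeping only, not the patched code:

import torch

class _LinearRef(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        has_batch_dim = x.dim() == 3
        if has_batch_dim:
            x = x.squeeze(1)                 # (m, 1, k) -> (m, k)
        ctx.save_for_backward(x, w)
        ctx.has_batch_dim = has_batch_dim
        out = x @ w.t()                      # (m, k) @ (k, n) -> (m, n)
        return out.unsqueeze(1) if has_batch_dim else out

    @staticmethod
    def backward(ctx, o_grad):
        x, w = ctx.saved_tensors
        if ctx.has_batch_dim:
            o_grad = o_grad.squeeze(1)       # (m, 1, n) -> (m, n)
        x_grad = o_grad @ w                  # (m, k)
        w_grad = o_grad.t() @ x              # (n, k)
        if ctx.has_batch_dim:
            x_grad = x_grad.unsqueeze(1)     # back to (m, 1, k)
        return x_grad, w_grad

x = torch.randn(5, 1, 16, requires_grad=True)
w = torch.randn(8, 16, requires_grad=True)
out = _LinearRef.apply(x, w)
out.sum().backward()
print(out.shape, x.grad.shape, w.grad.shape)
# torch.Size([5, 1, 8]) torch.Size([5, 1, 16]) torch.Size([8, 16])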