[fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig
Author: botbw
Date:   2024-08-09 15:51:06 +08:00 (committed by GitHub)
Parent: 8241c0c054
Commit: e4aadeee20

@@ -1,13 +1,14 @@
-from typing import Any, Optional
+from typing import Any, Optional, Tuple

 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging.version import Version
 from torch.distributed import ReduceOp


-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
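For context on what this hunk's signature describes, below is a minimal, illustrative per-tensor fp8 cast written against public torch APIs only. It is not the cast_to_fp8 from this file; the helper name cast_to_fp8_sketch and the clamp epsilon are assumptions, and per-channel scaling is omitted.

import torch
from typing import Tuple

def cast_to_fp8_sketch(inp: torch.Tensor, fp8_format: str = "e4m3") -> Tuple[torch.Tensor, torch.Tensor]:
    # Illustrative per-tensor fp8 quantization; not the library's cast_to_fp8.
    fp8_dtype = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_dtype).max
    # Per-tensor scale: map the largest magnitude onto the representable fp8 range.
    amax = inp.abs().max().float().clamp(min=1e-12)
    scale = fp8_max / amax
    inp_fp8 = (inp.float() * scale).to(fp8_dtype)
    # Return the quantized tensor together with the inverse scale needed to dequantize,
    # matching the Tuple[torch.Tensor, torch.Tensor] return annotation added above.
    return inp_fp8, scale.reciprocal()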
@@ -624,7 +625,13 @@ class _LinearFp8(torch.autograd.Function):
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
         out = torch._scaled_mm(
-            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
+            x_fp8,
+            ctx.w_fp8_t,
+            bias=bias,
+            out_dtype=ctx.out_dtype,
+            scale_a=inv_scale_x,
+            scale_b=inv_scale_w,
+            use_fast_accum=True,
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
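This hunk (and the two backward hunks below) threads use_fast_accum=True through torch._scaled_mm, which on Hopper-class GPUs selects the faster fp8 accumulation path at a small cost in accumulation precision. A minimal sketch of the call pattern follows; it is illustrative only, the helper name fp8_linear_sketch is made up, it assumes fp8 operands that already satisfy _scaled_mm's layout requirements (row-major first operand, column-major second operand), and the tuple check covers the fact that some torch releases return (out, amax) while newer ones return only the output tensor, whereas the diff unconditionally takes [0].

import torch

def fp8_linear_sketch(x_fp8, w_fp8_t, inv_scale_x, inv_scale_w, bias=None, out_dtype=torch.bfloat16):
    # Mirrors the call added in this hunk; not the library's implementation.
    res = torch._scaled_mm(
        x_fp8,                 # fp8 activations, row-major
        w_fp8_t,               # fp8 weight, transposed so it is column-major
        bias=bias,
        out_dtype=out_dtype,
        scale_a=inv_scale_x,   # dequantization scale of the activations
        scale_b=inv_scale_w,   # dequantization scale of the weight
        use_fast_accum=True,   # fast accumulation on Hopper tensor cores
    )
    # Unwrap defensively: older releases return (out, amax), newer ones a single tensor.
    return res[0] if isinstance(res, tuple) else res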
@@ -638,6 +645,7 @@ class _LinearFp8(torch.autograd.Function):
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
+            use_fast_accum=True,
         )[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
@@ -645,6 +653,7 @@ class _LinearFp8(torch.autograd.Function):
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
+            use_fast_accum=True,
         )[0]
         bias_grad = None
         if ctx.has_bias:
@@ -652,5 +661,13 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(input, weight, bias)
+if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
+
+    @torch.compile(mode="reduce-overhead", fullgraph=True)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
+
+else:
+
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
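The version gate above compiles linear_fp8 with mode="reduce-overhead" (which relies on CUDA graphs to cut per-call launch overhead, relevant here because the fp8 casts issue many small kernels) only on torch >= 2.3.0, and falls back to the plain eager wrapper otherwise. A hedged usage sketch follows; it assumes a CUDA device with fp8 support, torch >= 2.3.0, and an import path that is itself an assumption.

import torch
from colossalai.quantization.fp8 import linear_fp8  # import path is an assumption

# Same call convention as F.linear: out = x @ w.T + b, computed through fp8 _scaled_mm.
x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.zeros(128, device="cuda", dtype=torch.bfloat16)

out = linear_fp8(x, w, b)   # first call triggers compilation under reduce-overhead mode
print(out.shape)            # torch.Size([16, 128])
out.sum().backward()        # gradients also flow through the fp8 _scaled_mm backward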