Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 20:54:55 +00:00)
[fp8] use torch compile (torch >= 2.3.0) (#5979)
* [fp8] use torch compile (torch >= 2.4.0)
* [fp8] set use_fast_accum in linear
* [chore] formal version check
* [chore] fix sig
parent 8241c0c054
commit e4aadeee20
@@ -1,13 +1,14 @@
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
+from packaging.version import Version
 from torch.distributed import ReduceOp
 
 
-def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
+def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
     casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
     Args:
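For reference only (not part of the diff): the per-tensor branch of `cast_to_fp8`, now annotated as returning `Tuple[torch.Tensor, torch.Tensor]`, can be pictured with a minimal sketch like the one below. The helper name and scale handling are illustrative assumptions, not the repository's implementation.

from typing import Tuple

import torch


def cast_to_fp8_per_tensor_sketch(inp: torch.Tensor, fp8_format: str = "e4m3") -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical per-tensor scaling: map the largest magnitude onto the fp8 representable range.
    fp8_type = torch.float8_e4m3fn if fp8_format == "e4m3" else torch.float8_e5m2
    fp8_max = torch.finfo(fp8_type).max
    amax = inp.abs().max().float().clamp(min=1e-12)
    scale = fp8_max / amax
    fp8_tensor = (inp.float() * scale).clamp(-fp8_max, fp8_max).to(fp8_type)
    # Return the fp8 payload together with the inverse scale needed to dequantize it later.
    return fp8_tensor, scale.reciprocal()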
@@ -624,7 +625,13 @@ class _LinearFp8(torch.autograd.Function):
         ctx.inv_scale_x = inv_scale_x
         ctx.inv_scale_w = inv_scale_w
         out = torch._scaled_mm(
-            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
+            x_fp8,
+            ctx.w_fp8_t,
+            bias=bias,
+            out_dtype=ctx.out_dtype,
+            scale_a=inv_scale_x,
+            scale_b=inv_scale_w,
+            use_fast_accum=True,
         )[0]
         return out.reshape(*ctx.x_shape[:-1], w.shape[0])
 
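The forward hunk above adds `use_fast_accum=True` to `torch._scaled_mm`, trading a small amount of accumulation precision for speed. A rough standalone illustration of how such a call is fed follows; tensor names, shapes, and per-tensor scales are made up, and it assumes a CUDA device with fp8 support plus the torch 2.3/2.4-era signature this diff targets, where the result is a tuple indexed with `[0]`.

import torch

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)

# Per-tensor scales so the values fit into e4m3's representable range.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
scale_x = fp8_max / x.abs().max().float()
scale_w = fp8_max / w.abs().max().float()
x_fp8 = (x.float() * scale_x).to(torch.float8_e4m3fn)
w_fp8_t = (w.float() * scale_w).to(torch.float8_e4m3fn).t()  # second operand must be column-major

out = torch._scaled_mm(
    x_fp8,
    w_fp8_t,
    out_dtype=torch.bfloat16,
    scale_a=scale_x.reciprocal(),  # inverse scales dequantize the accumulated product
    scale_b=scale_w.reciprocal(),
    use_fast_accum=True,  # the flag this commit turns on
)[0]  # older torch returns (out, amax); newer releases return the tensor directly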
@@ -638,6 +645,7 @@ class _LinearFp8(torch.autograd.Function):
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_w,
+            use_fast_accum=True,
         )[0]
         w_grad = torch._scaled_mm(
             out_grad_fp8.t().contiguous(),
@@ -645,6 +653,7 @@ class _LinearFp8(torch.autograd.Function):
             out_dtype=ctx.out_dtype,
             scale_a=out_grad_scale,
             scale_b=ctx.inv_scale_x,
+            use_fast_accum=True,
         )[0]
         bias_grad = None
         if ctx.has_bias:
@@ -652,5 +661,13 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(input, weight, bias)
+if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
+
+    @torch.compile(mode="reduce-overhead", fullgraph=True)
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
+
+else:
+
+    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return _LinearFp8.apply(input, weight, bias)
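With the version gate in place, `linear_fp8` is wrapped in `torch.compile(mode="reduce-overhead", fullgraph=True)` on torch >= 2.3.0 and stays eager otherwise. A hedged usage sketch, assuming the function is importable from `colossalai.quantization.fp8` (path assumed) and that a CUDA device with fp8 support (e.g. H100) is available:

import torch

from colossalai.quantization.fp8 import linear_fp8  # module path assumed

x = torch.randn(8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w = torch.randn(256, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
b = torch.randn(256, device="cuda", dtype=torch.bfloat16)

# Forward and backward both run through fp8 torch._scaled_mm kernels;
# on torch >= 2.3.0 the wrapper is additionally compiled.
out = linear_fp8(x, w, b)  # expected shape (8, 256) in bfloat16
out.float().sum().backward()
print(out.shape, x.grad.shape, w.grad.shape)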