Mirror of https://github.com/hpcaitech/ColossalAI.git
[fp8] add fallback and make compile option configurable (#6092)
commit 5ddad486ca (parent 3b1d7d1ae8)
colossalai/quantization/fp8.py

@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:
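The new import means the `dynamic` flag of the `torch.compile` decorator further down is fixed once, at module import time. A minimal opt-in sketch (hypothetical usage; the module paths are assumed from this diff) would therefore patch the config before importing the FP8 module:

import colossalai.quantization.fp8_config as fp8_config

# dynamic_kernel is read when colossalai/quantization/fp8.py is imported,
# so it must be set before that import happens.
fp8_config.dynamic_kernel = True

from colossalai.quantization.fp8 import linear_fp8  # decorator now sees dynamic=True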
@@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
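The added guard is the fallback named in the commit title: FP8 GEMM kernels (e.g. torch._scaled_mm) generally require both matrix dimensions to be multiples of 16, so inputs that cannot satisfy this are routed to plain F.linear instead. An illustrative call, assuming the import path from this diff and an FP8-capable GPU (e.g. H100) for the fast path:

import torch
from colossalai.quantization.fp8 import linear_fp8  # path assumed from this diff

w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)

# 32 x 64 input: the flattened leading dim (32) and the feature dim (64)
# are both multiples of 16, so the compiled FP8 path is taken.
y_fast = linear_fp8(torch.randn(32, 64, device="cuda", dtype=torch.bfloat16), w)

# 30 x 64 input: 30 % 16 != 0, so the new guard returns F.linear(...)
# rather than feeding an unsupported shape to the FP8 GEMM.
y_safe = linear_fp8(torch.randn(30, 64, device="cuda", dtype=torch.bfloat16), w)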
colossalai/quantization/fp8_config.py (new file, 1 line)

@@ -0,0 +1 @@
+dynamic_kernel: bool = False
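For context on what the flag controls: with dynamic=False, torch.compile specializes the compiled kernel to each concrete input shape it encounters (recompiling, and under max-autotune re-tuning, per shape), while dynamic=True compiles with symbolic shapes so a single graph can serve many shapes. A standalone sketch, not ColossalAI code:

import torch

@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0

scale(torch.randn(8))   # first call triggers compilation
scale(torch.randn(12))  # new shape reuses the dynamic-shape graph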