[fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
This commit is contained in:
Hongxin Liu
2024-08-09 14:09:48 +08:00
committed by GitHub
parent 4b9bec8176
commit 8241c0c054
7 changed files with 21 additions and 7 deletions

View File

@@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(x, w, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)

View File

@@ -0,0 +1,23 @@
import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook
class FP8Hook(ColoParamOpHook):
def pre_forward(self, params) -> None:
pass
def post_forward(self, params) -> None:
pass
def pre_backward(self, params) -> None:
pass
def post_backward(self, params) -> None:
pass
def rewrite_op(self, func):
if func is F.linear:
return linear_fp8
return func