Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[fp8] support gemini plugin (#5978)
* [fp8] refactor hook
* [fp8] support gemini plugin
* [example] add fp8 option for llama benchmark
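The commit message mentions an fp8 option for the Llama benchmark. Below is a minimal sketch of how fp8 might be switched on through GeminiPlugin; the `use_fp8` flag name is an assumption based on this commit series, not confirmed by the diff, so check the plugin signature in your ColossalAI version.

```python
# Minimal sketch: enabling fp8 with GeminiPlugin (run under torchrun).
# NOTE: the `use_fp8` flag is an assumption based on the commit message;
# verify the actual parameter name in your ColossalAI version.
import colossalai
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # no-arg form; older releases take a config dict

plugin = GeminiPlugin(precision="bf16", use_fp8=True)  # `use_fp8` assumed
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```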
@@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(x, w, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
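For context, a hedged usage sketch of the renamed function: `linear_fp8` acts as a drop-in for `torch.nn.functional.linear` with the matmul performed in fp8. The import path `colossalai.quantization.fp8` is assumed from the file this hunk touches.

```python
# Sketch: using linear_fp8 as a drop-in for F.linear (requires a CUDA device
# with fp8 support, e.g. H100). Import path is assumed from this commit's file.
import torch
import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8

x = torch.randn(16, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)  # (out, in), like F.linear
b = torch.randn(128, device="cuda", dtype=torch.bfloat16)

y_fp8 = linear_fp8(x, w, b)   # matmul in fp8, per the diff above
y_ref = F.linear(x, w, b)     # full-precision reference
print((y_fp8 - y_ref).abs().max())  # small quantization error expected
```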