[fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility
Hongxin Liu authored 2024-08-07 18:21:08 +08:00, committed by GitHub
parent 76ea16466f
commit ccabcf6485
6 changed files with 102 additions and 3 deletions
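The headline change is wiring fp8 mixed precision into the hybrid parallel plugin. A minimal usage sketch is below; the `use_fp8` keyword and the launch/boost flow are assumptions inferred from the commit message and typical ColossalAI Booster usage, so verify the plugin signature in your installed version.

```python
# Hypothetical sketch of enabling fp8 AMP through HybridParallelPlugin.
# `use_fp8=True` is an assumed flag inferred from this commit's message;
# check the exact argument name in your ColossalAI version.
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes torchrun-style environment variables are set
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    precision="bf16",
    use_fp8=True,  # assumed: run eligible nn.Linear matmuls in fp8 via the fp8 hook
)
booster = Booster(plugin=plugin)
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
model, *_ = booster.boost(model)
```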


@@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function):
         if bias is not None:
             assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
         # ensure x and w are row-major
-        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
+        x = x.contiguous()
+        w = w.contiguous()
         ctx.x_shape = x.shape
         ctx.has_bias = bias is not None
         ctx.out_dtype = x.dtype
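
The compatibility fix is the swap shown above: instead of asserting that `x` and `w` are already row-major, `_LinearFp8.forward` now coerces both with `.contiguous()`, so transposed or sliced views no longer raise. The sketch below illustrates the same coercion pattern in a standalone autograd Function; it is a simplified stand-in (plain matmul, no fp8 casting), not ColossalAI's actual kernel path.

```python
# Minimal sketch (not ColossalAI's _LinearFp8): a custom autograd Function
# that coerces inputs to row-major instead of asserting contiguity, so that
# views such as w.t() or sliced activations are accepted.
import torch


class _ContiguousLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, bias=None):
        if bias is not None:
            assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
        # ensure x and w are row-major; the copy is skipped if already contiguous
        x = x.contiguous()
        w = w.contiguous()
        ctx.save_for_backward(x, w)
        ctx.has_bias = bias is not None
        out = x @ w.t()
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_out = grad_out.contiguous()
        grad_x = grad_out @ w
        grad_w = grad_out.reshape(-1, grad_out.shape[-1]).t() @ x.reshape(-1, x.shape[-1])
        grad_b = grad_out.reshape(-1, grad_out.shape[-1]).sum(0) if ctx.has_bias else None
        return grad_x, grad_w, grad_b


# A non-contiguous weight view no longer trips an assertion:
x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 16).t()  # shape (16, 8), non-contiguous view
_ContiguousLinear.apply(x, w).sum().backward()
```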