[fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
@@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function):
         if bias is not None:
             assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
         # ensure x and w are row-major
-        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
+        x = x.contiguous()
+        w = w.contiguous()
         ctx.x_shape = x.shape
         ctx.has_bias = bias is not None
         ctx.out_dtype = x.dtype
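The change replaces a hard assert on memory layout with an explicit normalization: instead of rejecting non-contiguous inputs, the forward pass converts them to row-major with .contiguous(). Below is a minimal sketch of that pattern, not the actual ColossalAI implementation: _LinearSketch is a hypothetical name, it assumes a 2D input, and a plain fp32 matmul stands in for the real fp8 kernel.

# Minimal sketch (hypothetical; plain matmul stands in for the fp8 kernel).
# The point: a kernel that assumes row-major tensors should normalize its
# inputs with .contiguous() rather than assert, so transposed/strided views
# from callers still work.
import torch

class _LinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w, bias=None):
        if bias is not None:
            assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
        # Normalize to row-major instead of rejecting strided inputs;
        # .contiguous() is a no-op when the tensor already is contiguous.
        x = x.contiguous()
        w = w.contiguous()
        ctx.save_for_backward(x, w)
        ctx.has_bias = bias is not None
        out = x @ w.t()
        return out + bias if bias is not None else out

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_out = grad_out.contiguous()
        grad_bias = grad_out.sum(dim=0) if ctx.has_bias else None
        return grad_out @ w, grad_out.t() @ x, grad_bias

# A non-contiguous input (e.g. a transposed view) now works:
x = torch.randn(8, 4).t()               # transposing makes x non-contiguous
w = torch.randn(16, 8)
y = _LinearSketch.apply(x, w, None)     # would have tripped the old assert

Since .contiguous() costs nothing on tensors that are already row-major, this keeps the fast path unchanged while accepting transposed or strided views, which is presumably the compatibility fix the third commit bullet refers to.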