Mirror of https://github.com/hpcaitech/ColossalAI.git

commit f7acfa1bd5 (parent 53823118f2)

    fix
@@ -767,11 +767,14 @@ class _ReduceForward(torch.autograd.Function):
     """
 
     @staticmethod
-    def forward(ctx, input_, process_group, fp8_communication=False):
+    def forward(ctx, input_, process_group, grad_scale=None, fp8_communication=False):
+        ctx.grad_scale = grad_scale
         return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
 
     @staticmethod
     def backward(ctx, grad_output):
-        return grad_output, None, None
+        if ctx.grad_scale is not None:
+            grad_output = grad_output * ctx.grad_scale
+        return grad_output, None, None, None
 
 
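The new grad_scale argument threads a gradient multiplier through the all-reduce: forward stores it on ctx, and backward applies it before returning. Because torch.autograd.Function.backward must return exactly one gradient per forward input (None for non-differentiable arguments), the four-argument forward pairs with a four-value return. Below is a minimal standalone sketch of the same pattern; _ScaledIdentity is a hypothetical stand-in for _ReduceForward so the example runs without a process group:

```python
import torch

# Hypothetical stand-in (not the ColossalAI implementation): mirrors how
# _ReduceForward threads grad_scale from forward to backward via ctx.
class _ScaledIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, grad_scale=None):
        ctx.grad_scale = grad_scale
        return input_.clone()

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.grad_scale is not None:
            grad_output = grad_output * ctx.grad_scale
        # Two forward inputs (input_, grad_scale) -> two return values.
        return grad_output, None

x = torch.ones(3, requires_grad=True)
_ScaledIdentity.apply(x, 0.5).sum().backward()
print(x.grad)  # tensor([0.5000, 0.5000, 0.5000])
```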
@@ -555,7 +555,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
         else:
             if self.seq_parallel_mode is None:
                 output_parallel = torch.matmul(input_, self.weight)
-                output = reduce_forward(output_parallel, self.process_group, self.fp8_communication)
+                output = reduce_forward(output_parallel, self.process_group, fp8_communication=self.fp8_communication)
             elif self.seq_parallel_mode == "split_gather":
                 output_parallel = torch.matmul(input_, self.weight)
                 output = reducescatter_forward_gather_backward(
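This call-site change is needed because grad_scale was inserted as the third positional parameter: the old positional call would now silently bind self.fp8_communication to grad_scale, leaving fp8 communication disabled. Passing the flag by keyword restores the intended binding. A small sketch of the hazard, with a hypothetical body (only the parameter order matches the new signature):

```python
# Stand-in with the same parameter order as the new reduce_forward signature;
# the body is hypothetical, the argument binding is the point.
def reduce_forward_sketch(input_, process_group, grad_scale=None, fp8_communication=False):
    return {"grad_scale": grad_scale, "fp8_communication": fp8_communication}

pg = object()  # stand-in for a torch.distributed process group

# Old positional call style: True lands in grad_scale, fp8 stays disabled.
print(reduce_forward_sketch("x", pg, True))
# {'grad_scale': True, 'fp8_communication': False}

# The fix applied in the hunk above: bind by keyword.
print(reduce_forward_sketch("x", pg, fp8_communication=True))
# {'grad_scale': None, 'fp8_communication': True}
```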