mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-23 02:20:49 +00:00
fix the merge
This commit is contained in:
@@ -767,12 +767,12 @@ class _ReduceForward(torch.autograd.Function):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, process_group):
|
||||
return _reduce(input_, process_group)
|
||||
def forward(ctx, input_, process_group, fp8_communication=False):
|
||||
return _reduce(input_, process_group, fp8_communication, fp8_format="e4m3")
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
return grad_output, None, None
|
||||
|
||||
|
||||
class _ReduceBackward(torch.autograd.Function):
|
||||
|
Reference in New Issue
Block a user