diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index bfe408065..aed9d8351 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -1082,8 +1082,8 @@ def split_forward_gather_backward(input_, dim, process_group, grad_scale=None, f return _SplitForwardGatherBackward.apply(input_, dim, process_group, grad_scale, fp8_communication) -def reduce_forward(input_, process_group, grad_scale=None, fp8_communication=False): - return _ReduceForward.apply(input_, process_group, grad_scale, fp8_communication) +def reduce_forward(input_, process_group, fp8_communication=False): + return _ReduceForward.apply(input_, process_group, fp8_communication) def reduce_backward(input_, process_group, fp8_communication=False):