diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index ed7b35233..bb403224f 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -775,7 +775,7 @@ class _ReduceForward(torch.autograd.Function):
     def backward(ctx, grad_output):
         if ctx.grad_scale is not None:
             grad_output = grad_output * ctx.grad_scale
-        return grad_output, None, None
+        return grad_output, None, None, None
 
 
 class _ReduceBackward(torch.autograd.Function):