[shardformer] support pp+tp+zero1 tests (#4531)

* [shardformer] fix opt test hanging

* fix

* test

* test

* test

* fix test

* fix test

* remove print

* add fix

* [shardformer] pp+tp+zero1
flybird11111
2023-08-30 21:29:18 +08:00
committed by GitHub
parent d367b88785
commit ec18fc7340
10 changed files with 101 additions and 3 deletions
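
For context, below is a minimal sketch of the parallel setup these tests exercise (pipeline parallel + tensor parallel + ZeRO stage 1) via ColossalAI's HybridParallelPlugin. The model, parallel sizes, and optimizer are illustrative assumptions, not taken from the test files in this commit.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})

# pp_size * tp_size must divide the world size; zero_stage=1 shards optimizer states only
plugin = HybridParallelPlugin(tp_size=2, pp_size=2, zero_stage=1,
                              precision="fp16", num_microbatches=4)
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 8))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# boost() wraps the optimizer; with zero_stage=1 the wrapper is the
# LowLevelZeroOptimizer patched in the diff below
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)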


@@ -333,12 +333,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self.zero_grad()

    def backward_by_grad(self, tensor, grad):
        # in a lower pipeline stage, the gradient is transferred from the higher stage,
        # so we need to pass the optimizer state down as well
        assert not (self._partition_grads and not self.require_grad_sync), \
            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
        if self.mixed_precision_mixin is not None:
            grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
        torch.autograd.backward(tensor, grad)

        if not self.require_grad_sync:
            return

        self._reduce_grad(self._partition_grads)

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()

        self.zero_grad()

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, gradient
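
The assertion added above forbids combining ZeRO-2 gradient partitioning with gradient accumulation under no_sync, and backward_by_grad now skips gradient reduction while grad sync is disabled. Below is a hedged sketch of the call pattern a pipeline schedule could use on a non-last stage; apart from backward_by_grad itself, the function and argument names are illustrative.

def backward_step(optimizer, stage_output, grad_from_next_stage):
    # On a non-last pipeline stage, the gradient w.r.t. the stage output arrives
    # from the next stage; backward_by_grad back-propagates it through this
    # stage and, unless gradient sync is disabled for gradient accumulation,
    # reduces the local gradients afterwards.
    optimizer.backward_by_grad(stage_output, grad_from_next_stage)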