[shardformer] hybridparallelplugin support gradients accumulation. (#5246)

* support gradients acc

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

fix

* fix

fix

* fix

fix

fix
This commit is contained in:
flybird11111
2024-01-17 15:22:33 +08:00
committed by GitHub
parent 2a0558d8ec
commit 46e091651b
2 changed files with 174 additions and 8 deletions

View File

@@ -165,7 +165,6 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
Returns:
None
"""
if self.tp_group.size() > 1 and self.shard_config.enable_sequence_parallelism:
if grads is not None:
# Synchronize provided gradient tensors across the tensor parallelism group.
@@ -487,7 +486,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
@@ -513,7 +511,6 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
Returns:
None
"""
# Call the superclass backward method to compute gradients.
super().backward_by_grad(tensor, grad)
@@ -674,7 +671,6 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
Returns:
None
"""
# Call the superclass `_sync_grad` method to synchronize gradients.
super()._sync_grad()
@@ -1081,7 +1077,7 @@ class HybridParallelPlugin(PipelinePluginBase):
return True
def support_no_sync(self) -> bool:
return False
return True
def control_checkpoint_io(self) -> bool:
return True
@@ -1175,9 +1171,14 @@ class HybridParallelPlugin(PipelinePluginBase):
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
):
return outputs
# Synchronize the grads of shared parameters of the model.
model.sync_shared_params()
# Synchronize sequence parallelism gradients of the model.
model.sync_sp_grads()
@@ -1241,5 +1242,8 @@ class HybridParallelPlugin(PipelinePluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
def no_sync(self, model: Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert (
self.zero_stage != 2
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()