Mirror of https://github.com/hpcaitech/ColossalAI.git
[zero]remove registered gradients hooks (#5687)
* remove registered hooks
* fix zero
* fix
@@ -735,7 +735,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Get all working gradients and gradients to be synchronized.
         all_working_grads = _get_all_working_grads()
         grads_to_sync = _get_grads_to_sync(all_working_grads)
-        if self.require_grad_sync and grads_to_sync is not None:
+        if self._grad_store.require_grad_sync and grads_to_sync is not None:
             # Synchronize sequence parallelism gradients if required.
             SeqParallelUtils.allreduce_partial_data_grad(process_group=self.tp_pg, grads=grads_to_sync)
         else:
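Every hunk in this commit makes the same substitution: the `require_grad_sync` flag is no longer read off the optimizer itself but off its low-level gradient store (`self._grad_store`). Below is a minimal sketch of that pattern; `GradientStore`, `ToyZeroOptimizer`, and their methods are illustrative assumptions, not ColossalAI's actual classes. The point is that the flag has a single home, and both the no-sync toggle (used for gradient accumulation) and the backward path consult it there.

```python
from contextlib import contextmanager


class GradientStore:
    """Hypothetical stand-in for the low-level ZeRO gradient store (assumption)."""

    def __init__(self) -> None:
        # Single source of truth for whether gradients should be synchronized.
        self.require_grad_sync = True


class ToyZeroOptimizer:
    """Toy optimizer showing where the flag now lives; not ColossalAI's API."""

    def __init__(self) -> None:
        self._grad_store = GradientStore()

    @contextmanager
    def no_sync(self):
        # Temporarily disable synchronization, e.g. while accumulating gradients.
        old = self._grad_store.require_grad_sync
        self._grad_store.require_grad_sync = False
        try:
            yield
        finally:
            self._grad_store.require_grad_sync = old

    def backward(self, compute_grads) -> None:
        compute_grads()
        # Read the flag from the store, mirroring the hunks in this commit.
        if self._grad_store.require_grad_sync:
            print("synchronizing (all-reducing) gradients")
        else:
            print("accumulating gradients locally, no sync")


if __name__ == "__main__":
    opt = ToyZeroOptimizer()
    with opt.no_sync():
        opt.backward(lambda: None)  # accumulation step: no sync
    opt.backward(lambda: None)      # boundary step: sync happens
```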
@@ -759,7 +759,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Call the superclass backward method to compute gradients.
         super().backward(loss, retain_graph)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
@@ -784,7 +784,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         # Call the superclass backward_by_grad method to compute gradients.
         super().backward_by_grad(tensor, grad)

-        if self.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
+        if self._grad_store.require_grad_sync and self.model.shard_config.enable_sequence_parallelism:
             # If gradient synchronization is required, sync sequence parallelism gradients.
             self._sync_sp_grads()
         else:
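Both `backward` and `backward_by_grad` guard the sequence-parallel sync (`self._sync_sp_grads()`) with the same store-held flag. A hedged usage sketch of a gradient-accumulation loop follows, reusing the hypothetical `ToyZeroOptimizer` from the previous sketch; the accumulation count and loop structure are assumptions for illustration only.

```python
# Accumulate over 4 micro-batches; only the last one triggers synchronization.
ACCUM_STEPS = 4

opt = ToyZeroOptimizer()
for micro_step in range(ACCUM_STEPS):
    if micro_step < ACCUM_STEPS - 1:
        # no_sync() flips _grad_store.require_grad_sync to False,
        # so backward() skips the (sequence-parallel) gradient sync.
        with opt.no_sync():
            opt.backward(lambda: None)
    else:
        # Flag is back to True: backward() performs the sync.
        opt.backward(lambda: None)
```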
@@ -1272,7 +1272,7 @@ class HybridParallelPlugin(PipelinePluginBase):

         # run with gradients accumulation
         if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer._grad_store.require_grad_sync == False
         ):
             return outputs
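The plugin-side hunk is the consumer of the same flag: if either the model or the ZeRO optimizer reports that gradient sync is disabled, the call returns early so gradients simply accumulate. A minimal sketch of that control flow is below; the function name and signature are hypothetical, and the `isinstance` check from the real hunk is dropped for brevity.

```python
def maybe_finish_step(model, optimizer, outputs):
    """Hypothetical early-return check mirroring the plugin hunk above (not ColossalAI code)."""
    # Accumulation step: either the model or the ZeRO optimizer has sync disabled,
    # so return early and let the gradients accumulate locally.
    if not model.require_grad_sync or not optimizer._grad_store.require_grad_sync:
        return outputs
    # Boundary step: fall through to gradient synchronization and the optimizer step.
    ...  # sync gradients, step the optimizer, zero the grads
    return outputs
```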