[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
Author: Hongxin Liu
Committed: 2024-09-11 16:11:25 +08:00 (committed by GitHub)
Parent: c54c4fcd15
Commit: 13946c4448
6 changed files with 31 additions and 17 deletions


@@ -216,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def unwrap(self):
@@ -229,12 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
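The rename also changes the gating condition: `_wait_all_gather()` re-derived `overlap_allgather or use_fp8` on every call, while `_hook_context()` simply activates whatever hooks are in `self.op_hooks`, presumably so the context stays in sync with the hooks that were actually installed. A minimal sketch of the pattern, using a hypothetical hook manager (ColossalAI's real `ColoParamOpHookManager` differs in detail):

```python
from contextlib import contextmanager, nullcontext

# Hypothetical stand-in for ColoParamOpHookManager: a global stack of hooks
# that parameter ops consult while the context is active.
_ACTIVE_HOOKS: list = []

@contextmanager
def use_hooks(*hooks):
    """Push the given hooks for the duration of the block, then pop them."""
    _ACTIVE_HOOKS.extend(hooks)
    try:
        yield
    finally:
        for _ in hooks:
            _ACTIVE_HOOKS.pop()

def hook_context(op_hooks):
    # The fix in one line: gate on "any hooks registered" instead of
    # re-checking feature flags (overlap_allgather / use_fp8) here.
    return use_hooks(*op_hooks) if len(op_hooks) > 0 else nullcontext()
```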
@@ -317,7 +313,8 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
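This is the backward-hook hotfix itself: autograd also triggers parameter ops, so the op hooks must be live while `super().backward(...)` runs, not only during forward. A hedged sketch of the shape of the fix, with a simplified wrapper standing in for `HybridParallelNaiveOptimizer`:

```python
import torch

class OptimizerWrapperSketch:
    """Simplified stand-in for the optimizer wrapper in the diff."""

    def __init__(self, model, optim: torch.optim.Optimizer):
        self.model = model  # expected to expose _hook_context(), as above
        self.optim = optim

    def backward(self, loss: torch.Tensor, *args, **kwargs):
        # Keep the param-op hooks active for the whole autograd pass,
        # mirroring `with self.model._hook_context(): super().backward(...)`.
        with self.model._hook_context():
            loss.backward(*args, **kwargs)
```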
@@ -540,7 +537,8 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
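The AMP optimizer receives the identical one-line change. The pattern generalizes: any wrapper that overrides `backward()` must enter the model's hook context itself, because the context entered around `forward()` has already exited by the time the optimizer issues the backward call.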
@@ -683,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
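For ZeRO, the wrapper does not wrap backward directly; it hands the bound method `model._hook_context` (a factory that returns a fresh context manager per call, not an already-entered context) down to `LowLevelZeroOptimizer` as `backward_context`. A hypothetical sketch of how such a parameter could be consumed, assuming the base class enters it around its own backward; the real `LowLevelZeroOptimizer` internals may differ:

```python
from contextlib import nullcontext
import torch

class LowLevelZeroSketch:
    """Illustrative consumer of a `backward_context` factory (hypothetical)."""

    def __init__(self, backward_context=None):
        # Default to nullcontext so the optimizer still works without hooks.
        self._backward_context = backward_context or nullcontext

    def backward(self, loss: torch.Tensor):
        # Call the factory each time: context managers are single-use.
        with self._backward_context():
            loss.backward()
        # ... gradient bucketing / reduce-scatter would follow here ...
```

Passing the factory rather than an entered context matters because a context manager generally cannot be re-entered across multiple backward passes.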
@@ -1206,6 +1207,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
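With the flag now threaded through to `HybridParallelZeroOptimizer`, fp8 communication can be enabled from the plugin. A hedged usage example (argument values are illustrative; only `fp8_communication` comes from this diff):

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Illustrative configuration; fp8_communication is the flag plumbed in here.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,            # routes through HybridParallelZeroOptimizer
    fp8_communication=True,  # now forwarded down to the ZeRO optimizer
)
booster = Booster(plugin=plugin)
```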
@@ -1371,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx, model._wait_all_gather():
+        with ctx, model._hook_context():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
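The pipeline path gets the same rename, so the hook context now spans the whole schedule step: every micro-batch's forward and backward runs with the op hooks active, stacked with the manual-reduction `no_sync` context. A condensed sketch of the control flow (names follow the diff; the `isinstance` check is replaced by duck typing for brevity):

```python
def pipeline_step_sketch(model, optimizer, schedule, data_iter, criterion,
                         return_loss=True, return_outputs=False):
    # ZeRO owns grad sync, otherwise the model wrapper does (as in the diff).
    ctx = optimizer.no_sync() if hasattr(optimizer, "no_sync") else model.no_sync()
    # Both contexts cover every micro-batch executed by the schedule.
    with ctx, model._hook_context():
        return schedule.forward_backward_step(
            model, data_iter, criterion, optimizer, return_loss, return_outputs
        )
```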