[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
Hongxin Liu
2024-09-11 16:11:25 +08:00
committed by GitHub
parent c54c4fcd15
commit 13946c4448
6 changed files with 31 additions and 17 deletions


@@ -100,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)
 
     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)
 
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
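
In short, the inline `overlap_allgather` check is replaced by a reusable `_hook_context()` helper that enters every registered parameter-op hook (e.g. an FP8 hook or an overlap-all-gather hook) whenever any are present, and falls back to a no-op context otherwise. A minimal sketch of the pattern, assuming the `ColoParamOpHookManager.use_hooks` API shown in the diff and an illustrative `op_hooks` list (the import path and class below are not part of this commit):

from contextlib import nullcontext

from colossalai.tensor.param_op_hook import ColoParamOpHookManager  # import path assumed

class HookedForward:
    """Sketch: run forward under whatever parameter-op hooks are registered."""

    def __init__(self, module, op_hooks=None):
        self.module = module
        self.op_hooks = list(op_hooks or [])  # illustrative; the plugin fills this list in practice

    def _hook_context(self):
        # Enter all registered hooks, or do nothing when the list is empty.
        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()

    def forward(self, *args, **kwargs):
        with self._hook_context():
            return self.module(*args, **kwargs)
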
@@ -520,7 +522,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
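
The second hunk threads the same context into the optimizer: `model._hook_context` is passed as a `backward_context` callable so the hooks also wrap the backward pass, which is the backward-hook part of the hotfix. The optimizer-side handling is not shown in this excerpt; a hedged sketch of how such a callable could be consumed (names below are illustrative, not the actual `LowLevelZeroOptimizer` internals):

from contextlib import nullcontext

import torch

class BackwardContextOptimizer:
    """Illustrative wrapper: enter a caller-supplied context around backward()."""

    def __init__(self, inner_optimizer, backward_context=None):
        self.inner_optimizer = inner_optimizer
        # backward_context is a zero-argument callable returning a context manager.
        self._backward_context = backward_context if backward_context is not None else nullcontext

    def backward(self, loss: torch.Tensor, retain_graph: bool = False):
        # The hooks entered by the context stay active while autograd runs.
        with self._backward_context():
            loss.backward(retain_graph=retain_graph)

    def step(self):
        self.inner_optimizer.step()

With this shape, `optimizer.backward(loss)` runs autograd under the same hook context as the forward pass instead of outside it.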