Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 02:51:59 +00:00)
[fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook
* [fp8] hotfix pipeline loss accumulation
@@ -100,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)
 
     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)
 
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):

@@ -520,7 +522,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
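The diff moves the hook setup out of forward() into a reusable _hook_context() method and hands that method to the optimizer as backward_context, so the parameter-op hooks can also be entered while gradients are computed. Below is a minimal sketch of that pattern, assuming only what the diff shows; SketchModel, SketchZeroOptimizer, and dummy_op_hook are hypothetical stand-ins, not ColossalAI APIs, and the real LowLevelZeroOptimizer may use backward_context differently.

from contextlib import contextmanager, nullcontext

@contextmanager
def dummy_op_hook(name):
    # Hypothetical stand-in for ColoParamOpHookManager.use_hooks(...).
    print(f"hook '{name}' armed")
    try:
        yield
    finally:
        print(f"hook '{name}' disarmed")

class SketchModel:
    def __init__(self, op_hooks):
        self.op_hooks = op_hooks

    def _hook_context(self):
        # Same shape as the hotfix: enter the hook context only if hooks exist.
        return dummy_op_hook(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()

    def forward(self, x):
        # Forward pass runs with the hooks armed, as in the patched forward().
        with self._hook_context():
            return x * 2  # placeholder computation

class SketchZeroOptimizer:
    def __init__(self, backward_context=None):
        # The hotfix threads the model's context factory into the optimizer.
        self._backward_context = backward_context or nullcontext

    def backward(self, loss_fn):
        # Backward also runs inside the same hook context, mirroring the fix.
        with self._backward_context():
            loss_fn()

model = SketchModel(op_hooks=["fp8_cast"])
optim = SketchZeroOptimizer(backward_context=model._hook_context)
model.forward(3)
optim.backward(lambda: print("backward executed under hook context"))

The design point, as far as the diff shows, is that the hook context is defined once on the model and shared by both passes, instead of being tied to the overlap_allgather flag inside forward() only.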