MemStatsCollectorStatic (#1765)

This commit is contained in:
Zihao
2022-11-07 16:49:03 +08:00
committed by GitHub
parent 327d07c44a
commit 20e255d4e8
4 changed files with 142 additions and 11 deletions

View File

@@ -267,7 +267,7 @@ class ZeroDDP(ColoDDP):
def forward(self, *args, **kwargs):
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter()
self.gemini_manager.pre_iter(*args)
with ParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32: