mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-05-03 01:19:15 +00:00
[hotfix] fix zero ddp warmup check (#2545)
This commit is contained in:
@@ -58,6 +58,10 @@ class GeminiManager:
|
||||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
@property
|
||||
def need_warmup(self) -> bool:
|
||||
return self.policy_name in ('auto', 'const')
|
||||
|
||||
def is_warmup(self):
|
||||
return self._warmup
|
||||
|
||||
|
||||
@@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP):
|
||||
# check whether we are in an inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
|
||||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
||||
Reference in New Issue
Block a user