[zero] check whether gradients have inf and nan in gpu (#712)

This commit is contained in:
HELSON
2022-04-11 15:40:13 +08:00
committed by GitHub
parent 715b86eadd
commit dbd96fe90a
3 changed files with 90 additions and 11 deletions

View File

@@ -148,6 +148,9 @@ class ShardedModelV2(nn.Module):
self._cuda_margin_space = 0
self.reuse_fp16_shard = reuse_fp16_shard
# record whether gradients have inf or nan
self.overflow_counter = 0
def adjust_stateful_tensor_layout(self) -> None:
self._stateful_tensor_mgr.adjust_layout()
@@ -345,6 +348,11 @@ class ShardedModelV2(nn.Module):
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
# record whether we have overflow
self.overflow_counter += torch.isinf(grad).any().item()
self.overflow_counter += torch.isnan(grad).any().item()
# move gradient to cpu
if param.colo_attr.offload_grad:
colo_model_data_move_to_cpu(grad)

View File

@@ -118,7 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
# Store fp32 param shards
@@ -210,20 +210,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
# check for overflow
for group in self.optim.param_groups:
for p in group['params']:
if has_inf_or_nan(p.grad):
self._found_overflow.fill_(1.0)
break
self._found_overflow.fill_(self.model.overflow_counter)
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.dp_process_group)
dist.all_reduce(self._found_overflow, group=self.dp_process_group)
# all-reduce over model parallel group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)
dist.all_reduce(self._found_overflow, group=self.mp_process_group)
return self._found_overflow.item() > 0
@@ -259,6 +252,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
else:
# release saved gradient
p.colo_attr.saved_grad.set_null()
self.model.overflow_counter = 0 # set overflow counter to zero
def sync_grad(self):
pass