[zero] add warning for ignored parameters (#2446)

This commit is contained in:
HELSON
2023-01-11 15:30:09 +08:00
committed by GitHub
parent 39163417a1
commit 2bfeb24308
2 changed files with 20 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple
@@ -78,8 +79,16 @@ class ZeroOptimizer(ColossalaiOptimizer):
if self.clipping_flag:
assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now"
params_list = [p for p in module.parameters() if not is_ddp_ignored(p)]
for p, fp32_p in zip(params_list, module.fp32_params):
ddp_param_list = []
for name, param in module.named_parameters():
if is_ddp_ignored(param):
if param.requires_grad:
warnings.warn(f"Parameter `{name}` is ignored by DDP but requires gradient! "
"You should handle its optimizer update by yourself!")
else:
ddp_param_list.append(param)
for p, fp32_p in zip(ddp_param_list, module.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
if chunk_16 not in self.chunk16_set:
chunk_16.l2_norm_flag = self.clipping_flag
@@ -290,6 +299,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
fake_params_list = list()
for param in group['params']:
if is_ddp_ignored(param):
continue
chunk16 = self.chunk_manager.get_chunk(param)
range_pair = get_range_pair(chunk16, param)
if range_pair[0] >= range_pair[1]: