add scatter/gather optim for pipeline (#123)

This commit is contained in:
ver217
2022-01-07 13:22:22 +08:00
committed by GitHub
parent 404e6f88ed
commit 293fb40c42
5 changed files with 166 additions and 56 deletions

View File

@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
# initialize amp
amp_mode = None
if fp16_cfg is not None and fp16_cfg.mode is not None:
# TODO: pipeline only support NAIVE AMP
cfg_ = fp16_cfg.copy()
amp_mode = cfg_.pop('mode')
if is_using_pp():
assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
if amp_mode == AMP_TYPE.NAIVE:
cfg_['clip_grad'] = clip_grad_norm
model, optimizer, criterion = convert_to_amp(model=model,