mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 19:40:28 +00:00
add scatter/gather optim for pipeline (#123)
This commit is contained in:
@@ -290,9 +290,10 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
|
||||
# initialize amp
|
||||
amp_mode = None
|
||||
if fp16_cfg is not None and fp16_cfg.mode is not None:
|
||||
# TODO: pipeline only support NAIVE AMP
|
||||
cfg_ = fp16_cfg.copy()
|
||||
amp_mode = cfg_.pop('mode')
|
||||
if is_using_pp():
|
||||
assert amp_mode == AMP_TYPE.NAIVE, 'Pipeline only support NaiveAMP currently'
|
||||
if amp_mode == AMP_TYPE.NAIVE:
|
||||
cfg_['clip_grad'] = clip_grad_norm
|
||||
model, optimizer, criterion = convert_to_amp(model=model,
|
||||
|
Reference in New Issue
Block a user