added buffer sync to naive amp model wrapper (#291)

Frank Lee
2022-03-02 16:47:17 +08:00
parent 8d653af408
commit e17e54e32a
4 changed files with 191 additions and 46 deletions
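The NaiveAMPModel buffer-sync change that gives this commit its title is not among the hunks shown below. Purely for orientation, here is a minimal sketch, using plain torch.distributed rather than the actual ColossalAI code, of what broadcasting a wrapped module's buffers across data-parallel ranks typically looks like; the function name and signature are assumptions, not part of this commit:

# Hypothetical sketch only; the real NaiveAMPModel implementation may differ.
# Broadcasting buffers (e.g. BatchNorm running statistics) from one source rank
# keeps them consistent across ranks, since gradient all-reduce alone does not
# cover buffers.
import torch.distributed as dist
import torch.nn as nn

def sync_module_buffers(module: nn.Module, src_rank: int = 0) -> None:
    # Requires torch.distributed to be initialised; broadcasts every registered
    # buffer in place from src_rank to all other ranks.
    for buffer in module.buffers():
        dist.broadcast(buffer, src=src_rank)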


@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
-def convert_to_zero(model: nn.Module,
-                    optimizer: Optimizer,
-                    level: int,
-                    zero_config: dict):
+def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
     assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
     if level in [1, 2]:
         if level == 2:
-            assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
+            if 'partition_grad' in zero_config:
+                assert zero_config['partition_grad'], \
+                    'Sharded Optimizer requires partition_grad to be True'
+            else:
+                zero_config['partiton_grad'] = True
         model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(model.parameters(), *zero_config)
+        optimizer = ShardedOptimizer(optimizer, **zero_config)
     else:
         model = ShardedModel(module=model, **zero_config)
     return model, optimizer
 __all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
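
For reference, a minimal sketch of how the reworked convert_to_zero helper might be called after this change. The model, the Adam optimizer, and the zero_config value are illustrative assumptions rather than part of the diff, and the import path assumes the file shown above is colossalai/zero/__init__.py:

# Hypothetical usage sketch, not part of the commit. Assumes the distributed
# environment has already been set up (e.g. via colossalai.launch) and that
# convert_to_zero is exported from colossalai.zero as in the diff above.
import torch.nn as nn
from torch.optim import Adam

from colossalai.zero import convert_to_zero

model = nn.Linear(1024, 1024).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)

# Level 2 takes the fp16 path: the model is wrapped in NaiveAMPModel with
# fp32 outputs and the existing optimizer is handed to ShardedOptimizer.
# partition_grad must be True at level 2, so it is passed explicitly here.
model, optimizer = convert_to_zero(model=model,
                                   optimizer=optimizer,
                                   level=2,
                                   zero_config={'partition_grad': True})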