Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
added buffer sync to naive amp model wrapper (#291)
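The title refers to keeping registered module buffers (for example BatchNorm running statistics) identical across data-parallel ranks when a model is wrapped for naive AMP; the wrapper change itself is not shown in the diff below, which only touches the ZeRO helper. As a rough illustration only, a broadcast-based buffer sync could look like the following minimal sketch, where sync_buffers is a hypothetical helper and torch.distributed is assumed to be initialized:

import torch.distributed as dist
import torch.nn as nn

def sync_buffers(module: nn.Module, src_rank: int = 0) -> None:
    # Hypothetical helper: broadcast every registered buffer from src_rank
    # so all data-parallel ranks hold the same values. The actual
    # NaiveAMPModel implementation in ColossalAI may differ.
    for buf in module.buffers():
        dist.broadcast(buf, src=src_rank)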
@@ -9,10 +9,7 @@ from .sharded_model import ShardedModel
 from .sharded_optim import ShardedOptimizer
 
 
-def convert_to_zero(model: nn.Module,
-                    optimizer: Optimizer,
-                    level: int,
-                    zero_config: dict):
+def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
     """
     A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
 
@@ -31,11 +28,16 @@ def convert_to_zero(model: nn.Module,
     assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
     if level in [1, 2]:
         if level == 2:
-            assert config['partition_grad'], 'ZeRO Optimizer requires partition_grad to be True'
+            if 'partition_grad' in zero_config:
+                assert zero_config['partition_grad'], \
+                    'Sharded Optimizer requires partition_grad to be True'
+            else:
+                zero_config['partition_grad'] = True
         model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(model.parameters(), *zero_config)
+        optimizer = ShardedOptimizer(optimizer, **zero_config)
     else:
         model = ShardedModel(module=model, **zero_config)
     return model, optimizer
 
 
+__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
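For context on how the helper above is meant to be called, here is a minimal usage sketch; the import path, the empty zero_config, and the toy model/optimizer are illustrative assumptions rather than the library's documented interface:

import torch.nn as nn
from torch.optim import SGD
from colossalai.zero import convert_to_zero  # import path assumed from the package shown above

model = nn.Linear(1024, 1024)
optimizer = SGD(model.parameters(), lr=1e-3)

# Level 2 takes the upper branch of the diff: the model is wrapped in
# NaiveAMPModel (fp32 outputs) and the optimizer in ShardedOptimizer,
# with partition_grad defaulting to True when absent from zero_config.
model, optimizer = convert_to_zero(model, optimizer, level=2, zero_config={})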