mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
[zero]fix zero ckptIO with offload (#4529)
* fix zero ckptio with offload
* fix load device
* saved tensors in ckpt should be on CPU
* fix unit test
* fix unit test
* add clear cache
* save memory for CI
This commit is contained in:
@@ -16,19 +16,21 @@ from colossalai.testing import (
|
||||
)
|
||||
|
||||
|
||||
# stage 1 and 2 process the optimizer/mode the same way
|
||||
# only test 2 is fine
|
||||
@clear_cache_before_run()
|
||||
@parameterize('stage', [2])
|
||||
@parameterize('shard', [True, False])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
|
||||
@parameterize('offload', [False, True])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32, cpu_offload=offload)
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = HybridAdam((model.parameters()), lr=0.001)
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
x = torch.randn(4, 3, 224, 224)
|
||||
x = x.to('cuda')
|
||||
x = torch.randn(1, 3, 224, 224, device='cuda')
|
||||
output = model(x)
|
||||
loss = criterion(output)
|
||||
booster.backward(loss, optimizer)
|
||||
@@ -50,15 +52,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
|
||||
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
    """Per-process entry point spawned by the test driver.

    Initializes the distributed environment for this rank, runs the
    low-level ZeRO checkpoint-IO checks, then releases cached GPU memory.
    """
    # An empty config dict is sufficient; host is fixed to localhost and
    # the port is chosen by the spawner.
    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    check_low_level_zero_checkpointIO()
    # Drop cached CUDA allocations so later tests in the same CI job
    # start with as much free device memory as possible.
    torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_low_level_zero_checkpointIO():
    """Spawn a two-process group and exercise low-level ZeRO checkpoint IO."""
    world_size = 2
    spawn(run_dist, world_size)
|
||||
|
||||
|
@@ -37,7 +37,7 @@ def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
atol = 4e-3
|
||||
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype)
|
||||
b = b.detach().to(dtype).to(a.device)
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
Reference in New Issue
Block a user