[checkpoint] save sharded optimizer states (#1237)

Jiarui Fang
2022-07-08 16:33:13 +08:00
committed by GitHub
parent 4a76084dc9
commit 20da6e48c8
3 changed files with 28 additions and 19 deletions


@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizer. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr scheduler. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
+    model_state = {'epoch': epoch, 'model': model.state_dict()}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
+    # TODO(): if tensor parallelism is used, optim_state contains SHARD ColoTensors.
+    # 1. convert the SHARD ColoTensors to REPLICATE
+    # 2. only rank 0 saves the REPLICATE tensors.
+    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
+    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
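
For context, a minimal sketch of the matching load side, assuming the file-name convention established by the hunk above (one epoch_{}_model.pth written by rank 0, and one epoch_{}_optim_rank_{}.pth per rank). The load_checkpoint name and signature here are illustrative assumptions, not the library's API.

import torch
import torch.distributed as dist


def load_checkpoint(dire: str, epoch: int, model, optimizer=None, lr_scheduler=None):
    """Hypothetical counterpart to save_checkpoint above (illustrative only)."""
    # every rank reads the single model file that rank 0 saved
    model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch), map_location='cpu')
    model.load_state_dict(model_state['model'])

    # each rank reads back the optimizer shard it saved itself
    optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()),
                             map_location='cpu')
    if optimizer is not None:
        optimizer.load_state_dict(optim_state['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(optim_state['lr_scheduler'])
    return model_state['epoch']

Because each rank writes its own optimizer file, no gather of SHARD tensors is needed at save time; the TODO in the hunk notes the SHARD-to-REPLICATE conversion that a single-file optimizer checkpoint would require.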