[checkpoint] save sharded optimizer states (#1237)
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
+    model_state = {'epoch': epoch, 'model': model.state_dict()}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
 
+    # TODO() If tensor parallelism is used, optim_state contains SHARD ColoTensors.
+    # 1. convert the SHARD ColoTensors to REPLICATE
+    # 2. only rank 0 saves the REPLICATE tensors.
     optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
+
+    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
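For context, the layout written by this hunk is one shared model file (written by rank 0 only) plus one optimizer/lr-scheduler file per rank. Below is a minimal sketch of the matching load path, assuming `save_checkpoint` was called on every rank, the process group is already initialized, and the same file-name scheme is used; the helper name `load_checkpoint_sketch` and its signature are illustrative, not part of this commit.

```python
import os

import torch
import torch.distributed as dist


def load_checkpoint_sketch(dire, epoch, model, optimizer=None, lr_scheduler=None):
    """Illustrative counterpart to the save path above (not ColossalAI API)."""
    rank = dist.get_rank()

    # Every rank reads the single model file that rank 0 wrote.
    model_state = torch.load(os.path.join(dire, 'epoch_{}_model.pth'.format(epoch)),
                             map_location='cpu')
    model.load_state_dict(model_state['model'])

    # Each rank reads back only the optimizer shard it saved itself.
    optim_state = torch.load(os.path.join(dire, 'epoch_{}_optim_rank_{}.pth'.format(epoch, rank)),
                             map_location='cpu')
    if optimizer is not None:
        optimizer.load_state_dict(optim_state['optimizer'])
    if lr_scheduler is not None:
        lr_scheduler.load_state_dict(optim_state['lr_scheduler'])
    return model_state['epoch']
```

Loading must therefore run with the same world size that produced the checkpoint, since rank `i` only finds its own optimizer shard on disk.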
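The TODO in the hunk describes the intended follow-up: gather each SHARD tensor into a full (REPLICATE) tensor so that rank 0 can write a single optimizer file, mirroring the model path. The sketch below illustrates that gather step with plain `torch.distributed` collectives, assuming an even split along dim 0; it is not the ColoTensor API, just an illustration of the idea.

```python
import torch
import torch.distributed as dist


def gather_shard_to_replicate(shard: torch.Tensor, group=None) -> torch.Tensor:
    """Illustrative only: rebuild a full tensor from equal per-rank shards
    split along dim 0, so rank 0 could save one replicated copy instead of
    per-rank files."""
    world_size = dist.get_world_size(group)
    gathered = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(gathered, shard.contiguous(), group=group)
    return torch.cat(gathered, dim=0)
```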