[checkpoint] save sharded optimizer states (#1237)

Jiarui Fang
2022-07-08 16:33:13 +08:00
committed by GitHub
parent 4a76084dc9
commit 20da6e48c8
3 changed files with 28 additions and 19 deletions


@@ -93,20 +93,17 @@ class ProcessGroup:
             if idx // self._tp_degree == self._rank_idx // self._tp_degree:
                 self._tp_rank_list.append(rank_id)
 
-        self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
-        self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
         self._has_cpu_groups = False
-        self._cpu_dp_process_group = None
-        self._cpu_tp_process_group = None
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
     def set_cpu_groups(self):
         if self.has_cpu_groups:
             return
 
         self.logger.info(
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
-        self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
-        self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
+        PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
 
     @property
     def has_cpu_groups(self):
@@ -152,13 +149,15 @@ class ProcessGroup:
         return len(self._tp_rank_list)
 
     def dp_process_group(self):
-        return self._dp_process_group
+        # return self._dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
 
     def tp_process_group(self):
-        return self._tp_process_group
+        # return self._tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
 
     def cpu_dp_process_group(self):
-        return self._cpu_dp_process_group
+        return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
 
     def cpu_tp_process_group(self):
-        return self._cpu_tp_process_group
+        return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
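After this change, ProcessGroup no longer caches torch.distributed groups in member attributes such as _tp_process_group; every accessor resolves the group from PYTORCHPGDICT_ by rank list and backend. A minimal sketch of what such a registry could look like, assuming groups are keyed by (rank list, backend) and created lazily with torch.distributed.new_group — the class and method below are illustrative assumptions, not ColossalAI's actual implementation:

    import torch.distributed as dist

    class ProcessGroupDict:
        """Keeps one group per (rank list, backend) pair and hands it out on demand."""

        def __init__(self):
            self._registry = {}

        def get(self, rank_list, backend='nccl'):
            key = (tuple(rank_list), backend)
            if key not in self._registry:
                # dist.new_group is a collective call: every process in the default
                # group must reach this line with the same arguments, so the first
                # .get() for a key doubles as group initialization.
                self._registry[key] = dist.new_group(ranks=rank_list, backend=backend)
            return self._registry[key]

Because lookups are deduplicated by key, the constructor calls above only warm the registry, and repeated calls to dp_process_group()/tp_process_group() hand back the same group instead of creating new ones.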


@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
         optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
         lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
-    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
+    model_state = {'epoch': epoch, 'model': model.state_dict()}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
+    # TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
+    # 1. convert SHARD ColoTensor to REPLICATE
+    # only rank 0 saves the REPLICATE tensors.
+    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
+    torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
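With this hunk, rank 0 alone writes the model weights, while every rank writes epoch_{}_optim_rank_{}.pth containing its own (possibly sharded) optimizer and scheduler state. A matching load step, sketched under the assumption that each rank restores exactly the shard it wrote — the load_checkpoint function below is illustrative, not this repository's API:

    import torch
    import torch.distributed as dist

    def load_checkpoint(dire, epoch, model, optimizer=None, lr_scheduler=None):
        # The model weights were saved once by rank 0; every rank reads the same file.
        model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch), map_location='cpu')
        model.load_state_dict(model_state['model'])

        if optimizer is not None:
            # Optimizer/scheduler states were saved per rank, so index the file by this rank.
            optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()),
                                     map_location='cpu')
            optimizer.load_state_dict(optim_state['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(optim_state['lr_scheduler'])

Keeping the optimizer shards per rank avoids gathering tensor-parallel optimizer states onto a single process at save time, which is the point of the TODO note in the diff above.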