mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[checkpoint] save sharded optimizer states (#1237)
This commit is contained in:
@@ -93,20 +93,17 @@ class ProcessGroup:
|
||||
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
|
||||
self._tp_rank_list.append(rank_id)
|
||||
|
||||
self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
self._has_cpu_groups = False
|
||||
self._cpu_dp_process_group = None
|
||||
self._cpu_tp_process_group = None
|
||||
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
def set_cpu_groups(self):
|
||||
if self.has_cpu_groups:
|
||||
return
|
||||
self.logger.info(
|
||||
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||
self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
|
||||
@property
|
||||
def has_cpu_groups(self):
|
||||
@@ -152,13 +149,15 @@ class ProcessGroup:
|
||||
return len(self._tp_rank_list)
|
||||
|
||||
def dp_process_group(self):
|
||||
return self._dp_process_group
|
||||
# return self._dp_process_group
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
|
||||
|
||||
def tp_process_group(self):
|
||||
return self._tp_process_group
|
||||
# return self._tp_process_group
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
|
||||
|
||||
def cpu_dp_process_group(self):
|
||||
return self._cpu_dp_process_group
|
||||
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
|
||||
def cpu_tp_process_group(self):
|
||||
return self._cpu_tp_process_group
|
||||
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
|
@@ -32,10 +32,15 @@ def save_checkpoint(dire: str,
|
||||
optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
|
||||
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
|
||||
"""
|
||||
model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
|
||||
model_state = {'epoch': epoch, 'model': model.state_dict()}
|
||||
if dist.get_rank() == 0:
|
||||
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
|
||||
|
||||
# TODO() If use tensor parallelism, optim_states contain SHARD ColoTensors.
|
||||
# 1. convert SHARD ColoTensor to REPLICATE
|
||||
# only rank 0 saves the REPLICATE tensors.
|
||||
optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler.state_dict()}
|
||||
|
||||
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user