[checkpoint] save sharded optimizer states (#1237)

This commit is contained in:
Jiarui Fang
2022-07-08 16:33:13 +08:00
committed by GitHub
parent 4a76084dc9
commit 20da6e48c8
3 changed files with 28 additions and 19 deletions

View File

@@ -93,20 +93,17 @@ class ProcessGroup:
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
self._tp_rank_list.append(rank_id)
self._tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
self._dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
self._has_cpu_groups = False
self._cpu_dp_process_group = None
self._cpu_tp_process_group = None
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def set_cpu_groups(self):
if self.has_cpu_groups:
return
self.logger.info(
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
self._cpu_tp_process_group = PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
self._cpu_dp_process_group = PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
@property
def has_cpu_groups(self):
@@ -152,13 +149,15 @@ class ProcessGroup:
return len(self._tp_rank_list)
def dp_process_group(self):
return self._dp_process_group
# return self._dp_process_group
return PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
def tp_process_group(self):
return self._tp_process_group
# return self._tp_process_group
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
def cpu_dp_process_group(self):
return self._cpu_dp_process_group
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
def cpu_tp_process_group(self):
return self._cpu_tp_process_group
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')