[hotfix] fix ddp for unit test test_gpt2 (#1326)

This commit is contained in:
HELSON
2022-07-15 18:19:52 +08:00
committed by GitHub
parent 250be4d31e
commit d49708ae43
4 changed files with 86 additions and 69 deletions

View File

@@ -21,7 +21,7 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
if pg_key not in self.dict:
self.logger = get_dist_logger('ProcessGroup')
self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
return self.dict[pg_key]
@@ -63,7 +63,6 @@ class ProcessGroup:
self._rank_list = ranks
self._rank_list.sort() # ensure that the list is in order
self._rank_idx = self._rank_list.index(self._rank)
self._world_size = len(self._rank_list)
if dp_degree is None and tp_degree is None:
@@ -84,19 +83,22 @@ class ProcessGroup:
f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
f"and TP degree {self._tp_degree}"
self._tp_rank_list = []
self._dp_rank_list = []
self._tp_rank_list = None
self._dp_rank_list = None
for idx, rank_id in enumerate(self._rank_list):
# idx and self._rank_idx in the same tp group
if idx % self._tp_degree == self._rank_idx % self._tp_degree:
self._dp_rank_list.append(rank_id)
if idx // self._tp_degree == self._rank_idx // self._tp_degree:
self._tp_rank_list.append(rank_id)
for i in range(self._dp_degree):
i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
PYTORCHPGDICT_.get(i_tp_list, 'nccl')
if self._rank in i_tp_list:
self._tp_rank_list = i_tp_list
for j in range(self._tp_degree):
j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
PYTORCHPGDICT_.get(j_dp_list, 'nccl')
if self._rank in j_dp_list:
self._dp_rank_list = j_dp_list
self._has_cpu_groups = False
PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
self.is_init = True
def set_cpu_groups(self):
@@ -106,6 +108,7 @@ class ProcessGroup:
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
self._has_cpu_groups = True
@property
def has_cpu_groups(self):
@@ -162,7 +165,9 @@ class ProcessGroup:
return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
def cpu_dp_process_group(self):
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
def cpu_tp_process_group(self):
assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')