Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[hotfix] fix ddp for unit test test_gpt2 (#1326)
@@ -21,7 +21,7 @@ class PyTorchProcessGroupDict(metaclass=SingletonMeta):
         if pg_key not in self.dict:
             self.logger = get_dist_logger('ProcessGroup')
-            self.logger.info(f'NCCL initialize TP group on {rank_list}', ranks=[0])
+            self.logger.info(f'NCCL initialize ProcessGroup on {rank_list}', ranks=[0])
             self.dict[pg_key] = torch.distributed.new_group(ranks=rank_list, backend=backend)
         return self.dict[pg_key]
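The dictionary touched above is a process-group cache: SingletonMeta makes PyTorchProcessGroupDict a singleton, and `get` memoizes the result of `torch.distributed.new_group` under a hashable (backend, ranks-tuple) key, so repeated `ProcessGroup` construction reuses one PyTorch group per rank set instead of creating duplicates. A minimal, torch-free sketch of that pattern (the names here are illustrative, not the ColossalAI API):

from typing import Callable, Dict, List, Tuple

class GroupCache:
    def __init__(self, make_group: Callable[[List[int], str], object]):
        # make_group stands in for torch.distributed.new_group
        self._make_group = make_group
        self._cache: Dict[Tuple[str, Tuple[int, ...]], object] = {}

    def get(self, rank_list: List[int], backend: str = 'nccl'):
        # lists are unhashable, so the key uses a tuple of ranks
        pg_key = (backend, tuple(rank_list))
        if pg_key not in self._cache:
            self._cache[pg_key] = self._make_group(rank_list, backend)
        return self._cache[pg_key]

cache = GroupCache(lambda ranks, backend: object())
assert cache.get([0, 1]) is cache.get([0, 1])    # second call reuses the cached group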
@@ -63,7 +63,6 @@ class ProcessGroup:
             self._rank_list = ranks
             self._rank_list.sort()    # ensure that the list is in order

-        self._rank_idx = self._rank_list.index(self._rank)
         self._world_size = len(self._rank_list)

         if dp_degree is None and tp_degree is None:
@@ -84,19 +83,22 @@ class ProcessGroup:
             f"the world size {self._world_size} should equals to the product of DP degree {self._dp_degree}" \
             f"and TP degree {self._tp_degree}"

-        self._tp_rank_list = []
-        self._dp_rank_list = []
+        self._tp_rank_list = None
+        self._dp_rank_list = None

-        for idx, rank_id in enumerate(self._rank_list):
-            # idx and self._rank_idx in the same tp group
-            if idx % self._tp_degree == self._rank_idx % self._tp_degree:
-                self._dp_rank_list.append(rank_id)
-            if idx // self._tp_degree == self._rank_idx // self._tp_degree:
-                self._tp_rank_list.append(rank_id)
+        for i in range(self._dp_degree):
+            i_tp_list = [self._rank_list[i * self._tp_degree + j] for j in range(self._tp_degree)]
+            PYTORCHPGDICT_.get(i_tp_list, 'nccl')
+            if self._rank in i_tp_list:
+                self._tp_rank_list = i_tp_list
+
+        for j in range(self._tp_degree):
+            j_dp_list = [self._rank_list[i * self._tp_degree + j] for i in range(self._dp_degree)]
+            PYTORCHPGDICT_.get(j_dp_list, 'nccl')
+            if self._rank in j_dp_list:
+                self._dp_rank_list = j_dp_list

         self._has_cpu_groups = False
-        PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')
-        PYTORCHPGDICT_.get(self._dp_rank_list, 'nccl')
         self.is_init = True

     def set_cpu_groups(self):
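The rewritten constructor treats the sorted rank list as a dp_degree x tp_degree grid in row-major order: row i (consecutive ranks) is a TP group and column j (stride tp_degree) is a DP group. Unlike the removed `_rank_idx` loop, every rank now registers every TP and DP group via `PYTORCHPGDICT_.get`, which matters because `torch.distributed.new_group` is a collective that must be entered by all processes, in the same order, even for groups they do not belong to; that is a plausible reading of why the old per-rank construction broke DDP in `test_gpt2`. A standalone sketch of the indexing (illustrative names, no torch required):

def build_groups(rank_list, dp_degree, tp_degree):
    assert len(rank_list) == dp_degree * tp_degree
    # row i: one TP group of tp_degree consecutive ranks
    tp_groups = [[rank_list[i * tp_degree + j] for j in range(tp_degree)]
                 for i in range(dp_degree)]
    # column j: one DP group with stride tp_degree
    dp_groups = [[rank_list[i * tp_degree + j] for i in range(dp_degree)]
                 for j in range(tp_degree)]
    return tp_groups, dp_groups

tp, dp = build_groups(list(range(8)), dp_degree=2, tp_degree=4)
assert tp == [[0, 1, 2, 3], [4, 5, 6, 7]]    # rows: tensor parallel
assert dp == [[0, 4], [1, 5], [2, 6], [3, 7]]    # columns: data parallel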
@@ -106,6 +108,7 @@ class ProcessGroup:
             f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
         PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
         PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
+        self._has_cpu_groups = True

     @property
     def has_cpu_groups(self):
@@ -162,7 +165,9 @@ class ProcessGroup:
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'nccl')

     def cpu_dp_process_group(self):
+        assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')

     def cpu_tp_process_group(self):
+        assert self._has_cpu_groups
         return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
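Together, the last two hunks tighten the CPU-group contract: `set_cpu_groups` now records `self._has_cpu_groups = True` after creating the Gloo groups, and the `cpu_*_process_group` accessors assert that flag instead of silently returning a group that was never initialized. A hedged usage sketch (import path and constructor arguments inferred from the diff, assuming torch.distributed is already initialized across dp_degree * tp_degree ranks):

from colossalai.tensor import ProcessGroup    # import path assumed, not shown in the diff

pg = ProcessGroup(tp_degree=2, dp_degree=2)    # NCCL TP/DP groups registered here
pg.set_cpu_groups()                            # Gloo groups created, _has_cpu_groups set
cpu_dp_pg = pg.cpu_dp_process_group()          # would assert without the call above
cpu_tp_pg = pg.cpu_tp_process_group()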