Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-02 17:46:42 +00:00
[Tensor] add cpu group to ddp (#1200)
@@ -114,10 +114,8 @@ class DistSpecManager:
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
-            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
-            f"collective function, however, we got {tensor.device.type} device and " \
-            f"{old_dist_spec.process_group.backend} backend"
+        assert tensor.device.type == "cuda", "Currently, only CUDA Tensors are supported for the requested AlltoAll " \
+            f"collective function, however, we got {tensor.device.type} device"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
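The narrowed assertion follows from the rest of the commit: ProcessGroup no longer exposes a user-set backend, so the AlltoAll redistribution path can only check the tensor's device. A minimal caller-side sketch of that expectation (the helper is illustrative, not part of the commit):

    import torch

    def prepare_for_alltoall(tensor: torch.Tensor) -> torch.Tensor:
        # The redistribution path now asserts only that the tensor lives on a CUDA
        # device (the NCCL group is created internally), so a CPU tensor has to be
        # moved to the GPU before an AlltoAll is requested.
        if tensor.device.type != "cuda":
            tensor = tensor.cuda()
        return tensor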
@@ -18,7 +18,6 @@ class ProcessGroup:
     def __init__(self,
                  rank: Optional[int] = None,
                  ranks: Optional[List[int]] = None,
-                 backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
         assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
@@ -32,7 +31,6 @@ class ProcessGroup:
         else:
             self._rank_list = ranks

-        self._backend = backend
         self._world_size = len(self._rank_list)

         if dp_degree is None and tp_degree is None:
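With the backend parameter and the _backend field removed, a ProcessGroup is described by its ranks and the TP/DP degrees alone; NCCL groups (and, on demand, Gloo groups) are created internally. A usage sketch, assuming torch.distributed is already initialized (e.g. by colossalai.launch) and that ProcessGroup is importable from colossalai.tensor:

    import torch.distributed as dist
    from colossalai.tensor import ProcessGroup  # import path assumed

    # four processes arranged as 2-way tensor parallelism x 2-way data parallelism
    pg = ProcessGroup(rank=dist.get_rank(),
                      ranks=list(range(dist.get_world_size())),
                      tp_degree=2,
                      dp_degree=2)   # no `backend=` argument any more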
@@ -59,16 +57,26 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        assert backend == 'nccl'
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='nccl')
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='nccl')

         self.logger = get_dist_logger('ProcessGroup')
-        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+        self.logger.info(
+            f'{self._rank} NCCL initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')

+        self._has_cpu_groups = False
+
+    def set_cpu_groups(self):
+        if self.has_cpu_groups:
+            return
+        self.logger.info(
+            f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        self._cpu_tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend='gloo')
+        self._cpu_dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend='gloo')
+
     @property
-    def backend(self):
-        return self._backend
+    def has_cpu_groups(self):
+        return self._has_cpu_groups

     def __eq__(self, obj: 'ProcessGroup') -> bool:
         if not isinstance(obj, ProcessGroup):
@@ -81,8 +89,6 @@ class ProcessGroup:
             assert False
         if self._dp_rank_list != obj._dp_rank_list:
             assert False
-        if self._backend != obj._backend:
-            assert False
         if self._tp_degree != obj._tp_degree:
             return False
         if self._dp_degree != obj._dp_degree:
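Equality is now decided by the rank lists and the TP/DP degrees; the backend comparison is gone. A short sketch, reusing the pg from the previous example:

    # a second group built over the same ranks and degrees compares equal,
    # even though no backend is specified (or stored) anywhere
    other = ProcessGroup(rank=dist.get_rank(),
                         ranks=list(range(dist.get_world_size())),
                         tp_degree=2,
                         dp_degree=2)
    assert pg == other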
@@ -112,3 +118,9 @@ class ProcessGroup:

     def tp_process_group(self):
         return self._tp_process_group
+
+    def cpu_dp_process_group(self):
+        return self._cpu_dp_process_group
+
+    def cpu_tp_process_group(self):
+        return self._cpu_tp_process_group
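The CPU groups are opt-in: set_cpu_groups() builds Gloo groups over the same TP/DP rank lists, has_cpu_groups reports whether that has happened, and cpu_tp_process_group() / cpu_dp_process_group() hand the Gloo groups to CPU-side collectives (the DDP use case this PR targets). A hedged sketch of how the pieces fit together, continuing the example above:

    import torch
    import torch.distributed as dist

    if not pg.has_cpu_groups:     # property added by this commit
        pg.set_cpu_groups()       # creates the Gloo TP/DP groups

    # NCCL group for a CUDA tensor ...
    gpu_t = torch.ones(4, device='cuda')
    dist.all_reduce(gpu_t, group=pg.tp_process_group())

    # ... Gloo group for a CPU tensor
    cpu_t = torch.ones(4)
    dist.all_reduce(cpu_t, group=pg.cpu_dp_process_group())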
@@ -17,13 +17,9 @@ class TensorSpec(object):
        self.compute_spec = compute_spec
        self.dist_spec = dist_spec

    # TODO(jiaruifang) actually need tp process group
    def get_process_group(self):
        return self.dist_spec.process_group

    def get_process_group_size(self):
        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

    def get_placement(self):
        return self.dist_spec.placement