[refactor] remove gpc dependency in colotensor's _ops (#1189)

Jiarui Fang
2022-07-04 18:54:37 +08:00
committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions
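
Note: the commit swaps lookups through the global parallel context (gpc) for queries on the ProcessGroup carried by each tensor's spec. A rough before/after sketch of the pattern, assuming `colo_tensor` is a ColoTensor instance; the gpc call shown is illustrative of the old style and is not copied from this diff:

    # before: parallel sizes came from the global context (old-style API, for contrast only)
    # from colossalai.core import global_context as gpc
    # tp_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)

    # after: the tensor's own spec answers the question
    pg = colo_tensor.get_process_group()
    tp_size = colo_tensor.get_tp_world_size()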

View File

@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

+    def get_process_group(self) -> 'ProcessGroup':
+        return self._tensor_spec.dist_spec.process_group
+
+    def get_tp_world_size(self) -> int:
+        return self._tensor_spec.dist_spec.process_group.tp_world_size()
+
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
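
The two accessors added above expose the spec's ProcessGroup directly on the tensor. A minimal usage sketch, assuming `t` is a ColoTensor whose TensorSpec already carries a ProcessGroup:

    pg = t.get_process_group()        # the colossalai.tensor.ProcessGroup stored in the dist spec
    tp_size = t.get_tp_world_size()   # shorthand for pg.tp_world_size()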

View File

@@ -5,6 +5,7 @@ from contextlib import contextmanager
import torch
import torch.distributed as dist
from packaging import version
+from colossalai.logging import get_dist_logger

# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -64,7 +65,7 @@ class DistSpecManager:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)
        chunk = tensor
-        idx = dist_spec.process_group.rank()
+        idx = dist_spec.process_group.tp_local_rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
@@ -91,8 +92,9 @@ class DistSpecManager:
        saved_dev = tensor.device
        tensor.data = tensor.data.cuda()
-        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
-        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
+        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
+        assert tensor.device.type == 'cuda'
+        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
@@ -108,14 +110,14 @@ class DistSpecManager:
    @staticmethod
    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        world_size = old_dist_spec.process_group.size()
+        world_size = old_dist_spec.process_group.tp_world_size()
        if world_size == 1:
            return tensor

-        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
            f"collective function, however, we got {tensor.device.type} device and " \
-            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+            f"{old_dist_spec.process_group.backend} backend"

        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
@@ -126,7 +128,7 @@ class DistSpecManager:
        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
-        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
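
Throughout DistSpecManager the collectives now size their buffers with tp_world_size() and hand torch.distributed the raw backend group from tp_process_group(), instead of passing the wrapper object. A sketch of that gather pattern, assuming `pg` is a colossalai.tensor.ProcessGroup and `tensor` is a CUDA tensor; the helper name is made up for illustration:

    import torch
    import torch.distributed as dist

    def gather_from_tp_group(tensor: torch.Tensor, pg) -> torch.Tensor:
        # one receive buffer per tensor-parallel rank
        buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
        # torch.distributed expects the underlying NCCL group, not the wrapper
        dist.all_gather(buffer, tensor, group=pg.tp_process_group())
        # stitch the shards back together along the first dimension
        return torch.cat(buffer, dim=0)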

View File

@@ -1,5 +1,5 @@
from enum import Enum
-from torch.distributed import ProcessGroup
+from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod
@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
-    assert process_group is not None
+    assert process_group is not None and isinstance(process_group, ProcessGroup)
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
+    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
    return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
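
shard now checks both the wrapper type and that the partition layout covers exactly the tensor-parallel ranks. A call sketch, assuming `pg` is a colossalai.tensor.ProcessGroup spanning four tensor-parallel ranks:

    # split the last dimension into 4 parts, one per TP rank;
    # prod(num_partitions) must equal pg.tp_world_size()
    spec = shard(pg, dims=[-1], num_partitions=[4])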

View File

@@ -1,5 +1,6 @@
import torch
from typing import List, Optional
+from colossalai.logging import get_dist_logger


class ProcessGroup:
@@ -41,12 +42,12 @@ class ProcessGroup:
        if dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
-            self._tp_degree = self._world_size / dp_degree
+            self._tp_degree = self._world_size // dp_degree

        if not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
-            self._dp_degree = self._world_size / tp_degree
+            self._dp_degree = self._world_size // tp_degree

        self._tp_rank_list = []
        self._dp_rank_list = []
@@ -58,12 +59,48 @@ class ProcessGroup:
            if rank_id // self._tp_degree == self._rank // self._tp_degree:
                self._tp_rank_list.append(rank_id)

-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
+        assert backend == 'nccl'
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+
+        self.logger = get_dist_logger('ProcessGroup')
+        self.logger.info(f'rank {self._rank} initializes TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+    @property
+    def backend(self):
+        return self._backend
+
+    def __eq__(self, obj: 'ProcessGroup') -> bool:
+        if not isinstance(obj, ProcessGroup):
+            return False
+        if self._rank != obj._rank:
+            assert False
+        if self._rank_list != obj._rank_list:
+            assert False
+        if self._tp_rank_list != obj._tp_rank_list:
+            assert False
+        if self._dp_rank_list != obj._dp_rank_list:
+            assert False
+        if self._backend != obj._backend:
+            assert False
+        if self._tp_degree != obj._tp_degree:
+            return False
+        if self._dp_degree != obj._dp_degree:
+            return False
+        return True
+
+    def rank(self):
+        return self._rank
+
+    def world_size(self):
+        return self._world_size
+
+    def tp_local_rank(self):
+        return self._rank % self._tp_degree
+
+    def dp_local_rank(self):
+        return self._rank // self._tp_degree
+
+    def dp_world_size(self):
+        return len(self._dp_rank_list)
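
The new helper methods make the wrapper self-describing: callers read ranks and group sizes from the ProcessGroup itself rather than from the global context. A query sketch, assuming torch.distributed is already initialized over 8 processes and the group is built with tp_degree=2; the constructor arguments shown are illustrative:

    pg = ProcessGroup(tp_degree=2)                   # dp_degree is then derived as world_size // tp_degree
    print(pg.rank(), pg.world_size())                # global rank, total number of ranks
    print(pg.tp_local_rank(), pg.dp_local_rank())    # position inside the TP and DP sub-groups
    print(pg.tp_world_size(), pg.dp_world_size())    # sizes of the two sub-groups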

View File

@@ -17,11 +17,12 @@ class TensorSpec(object):
        self.compute_spec = compute_spec
        self.dist_spec = dist_spec

+    # TODO(jiaruifang) actually need tp process group
    def get_process_group(self):
        return self.dist_spec.process_group

    def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group)
+        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

    def get_placement(self):
        return self.dist_spec.placement
@@ -30,7 +31,7 @@ class TensorSpec(object):
        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
            or (len(self.dist_spec.num_partitions) == 1
                and self.dist_spec.num_partitions[0] == 1) \
-            or (self.dist_spec.process_group.size() == 1)
+            or (self.dist_spec.process_group.tp_world_size() == 1)

    def is_shard_1dcol(self):
        return self.dist_spec.placement == DistPlacementPattern.SHARD \