Mirror of https://github.com/hpcaitech/ColossalAI.git
[refactor] remove gpc dependency in colotensor's _ops (#1189)
@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
 
+    def get_process_group(self) -> 'ProcessGroup':
+        return self._tensor_spec.dist_spec.process_group
+
+    def get_tp_world_size(self) -> int:
+        return self._tensor_spec.dist_spec.process_group.tp_world_size()
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
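The two accessors added above let an op read the tensor-parallel topology from the ColoTensor itself instead of from the global parallel context (gpc). A minimal caller sketch; the helper name is hypothetical and not part of the commit:

def needs_tp_collective(colo_tensor) -> bool:
    # ColoTensor now answers topology questions from its own spec,
    # so callers no longer need a gpc lookup.
    return colo_tensor.get_tp_world_size() > 1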
@@ -5,6 +5,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 from packaging import version
+from colossalai.logging import get_dist_logger
 
 
 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
@@ -64,7 +65,7 @@ class DistSpecManager:
         DistSpecManager._sanity_check(old_dist_spec, dist_spec)
 
         chunk = tensor
-        idx = dist_spec.process_group.rank()
+        idx = dist_spec.process_group.tp_local_rank()
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]
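Replacing process_group.rank() with tp_local_rank() makes the shard index local to the tensor-parallel group. The loop above sets up a mixed-radix decomposition of that index over num_partitions; a standalone sketch of the same arithmetic (the chunk-selection step itself lies outside this hunk and is assumed, not shown in the diff):

from numpy import prod

def partition_indices(idx: int, num_partitions) -> list:
    # Decompose a TP-local rank into one partition index per sharded dim,
    # mirroring the num_parts //= num_partitions[i] bookkeeping above.
    num_parts = prod(num_partitions)
    indices = []
    for parts in num_partitions:
        num_parts //= parts
        indices.append(idx // num_parts)
        idx %= num_parts
    return indices

assert partition_indices(5, [2, 4]) == [1, 1]   # TP-local rank 5 of 8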
@@ -91,8 +92,9 @@ class DistSpecManager:
         saved_dev = tensor.device
         tensor.data = tensor.data.cuda()
 
-        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
-        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
+        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
+        assert tensor.device.type == 'cuda'
+        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
             new_buffer = []
             dim = old_dist_spec.dims[i]
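The gather now sizes its buffer by the tensor-parallel world size and runs all_gather on the TP process group instead of on the raw ProcessGroup object. A minimal sketch of the pattern for a tensor sharded along a single dim, assuming torch.distributed is already initialized and tp_group is a torch.distributed process group:

import torch
import torch.distributed as dist

def gather_along_dim(tensor: torch.Tensor, dim: int, tp_group) -> torch.Tensor:
    # One buffer slot per TP peer, then all_gather and concatenate back
    # along the sharded dim, as the patched _gather does.
    world_size = dist.get_world_size(group=tp_group)
    buffer = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(buffer, tensor, group=tp_group)
    return torch.cat(buffer, dim=dim).contiguous()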
@@ -108,14 +110,14 @@ class DistSpecManager:
 
     @staticmethod
     def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        world_size = old_dist_spec.process_group.size()
+        world_size = old_dist_spec.process_group.tp_world_size()
         if world_size == 1:
             return tensor
 
-        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
             "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
             f"collective function, however, we got {tensor.device.type} device and " \
-            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+            f"{old_dist_spec.process_group.backend} backend"
 
         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
@@ -126,7 +128,7 @@ class DistSpecManager:
 
         scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
         gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
-        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
 
         output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
         assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
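_all_to_all now takes its world size, backend check, and communication group from the new ProcessGroup wrapper; the resharding pattern itself is unchanged: split the local shard along the new dim, exchange pieces, and concatenate along the old dim. A sketch under the assumption that the tensor divides evenly, so every exchanged piece has the same shape (the original computes the receive shapes explicitly):

import torch
import torch.distributed as dist

def reshard_all_to_all(tensor, gather_dim, scatter_dim, tp_group):
    # Requires CUDA tensors and an NCCL-backed group, as the assert above enforces.
    world_size = dist.get_world_size(group=tp_group)
    scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
    gather_list = [torch.empty_like(scatter_list[0]) for _ in range(world_size)]
    dist.all_to_all(gather_list, scatter_list, group=tp_group)
    # Each peer contributed one slice of the gather_dim extent.
    return torch.cat(gather_list, dim=gather_dim).contiguous()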
@@ -1,5 +1,5 @@
 from enum import Enum
-from torch.distributed import ProcessGroup
+from colossalai.tensor import ProcessGroup
 from typing import Optional, List
 from numpy import prod
 
@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
 
 
 def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
-    assert process_group is not None
+    assert process_group is not None and isinstance(process_group, ProcessGroup)
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
+    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
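The tightened asserts make shard() reject anything that is not the colossalai ProcessGroup wrapper and require the partition plan to cover exactly the tensor-parallel group rather than the whole world. A tiny sketch of the invariant:

from numpy import prod

def check_shard_plan(num_partitions, tp_world_size):
    # prod(num_partitions) must match the TP world size, not the global world size.
    assert prod(num_partitions) == tp_world_size, f"{num_partitions} {tp_world_size}"

check_shard_plan([2, 2], tp_world_size=4)   # 2 * 2 == 4, passes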
@@ -1,5 +1,6 @@
 import torch
 from typing import List, Optional
+from colossalai.logging import get_dist_logger
 
 
 class ProcessGroup:
@@ -41,12 +42,12 @@ class ProcessGroup:
         if dp_degree and not tp_degree:
             self._dp_degree = dp_degree
             assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
-            self._tp_degree = self._world_size / dp_degree
+            self._tp_degree = self._world_size // dp_degree
 
         if not dp_degree and tp_degree:
             self._tp_degree = tp_degree
             assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
-            self._dp_degree = self._world_size / tp_degree
+            self._dp_degree = self._world_size // tp_degree
 
         self._tp_rank_list = []
         self._dp_rank_list = []
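Switching from / to // matters because true division yields a float, and the degree fields feed integer rank arithmetic (modulo and floor division) later in the class:

world_size, dp_degree = 8, 2
assert world_size / dp_degree == 4.0             # float: would break rank % tp_degree arithmetic
assert world_size // dp_degree == 4              # int: what the degree fields need
assert isinstance(world_size // dp_degree, int)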
@@ -58,12 +59,48 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)
 
-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
+        assert backend == 'nccl'
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
 
+        self.logger = get_dist_logger('ProcessGroup')
+        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+
+    @property
+    def backend(self):
+        return self._backend
+
+    def __eq__(self, obj: 'ProcessGroup') -> bool:
+        if not isinstance(obj, ProcessGroup):
+            return False
+        if self._rank != obj._rank:
+            assert False
+        if self._rank_list != obj._rank_list:
+            assert False
+        if self._tp_rank_list != obj._tp_rank_list:
+            assert False
+        if self._dp_rank_list != obj._dp_rank_list:
+            assert False
+        if self._backend != obj._backend:
+            assert False
+        if self._tp_degree != obj._tp_degree:
+            return False
+        if self._dp_degree != obj._dp_degree:
+            return False
+        return True
+
     def rank(self):
         return self._rank
 
     def world_size(self):
         return self._world_size
 
+    def tp_local_rank(self):
+        return self._rank % self._tp_degree
+
+    def dp_local_rank(self):
+        return self._rank // self._tp_degree
+
+    def dp_world_size(self):
+        return len(self._dp_rank_list)
+
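The new helpers encode a 2-D, TP-major rank layout: ranks sharing the same rank // tp_degree form a TP group, tp_local_rank is rank % tp_degree, and dp_local_rank is rank // tp_degree. A standalone sketch of that layout for world_size=8, tp_degree=4; the column-wise DP grouping is inferred from dp_local_rank, since the DP loop is not shown in this hunk:

def show_layout(world_size: int, tp_degree: int):
    for rank in range(world_size):
        # Same row -> same TP group; same column -> same DP group (assumed).
        tp_group = [r for r in range(world_size) if r // tp_degree == rank // tp_degree]
        dp_group = [r for r in range(world_size) if r % tp_degree == rank % tp_degree]
        print(f"rank {rank}: tp_local_rank={rank % tp_degree} "
              f"dp_local_rank={rank // tp_degree} tp={tp_group} dp={dp_group}")

show_layout(world_size=8, tp_degree=4)
# rank 0: tp_local_rank=0 dp_local_rank=0 tp=[0, 1, 2, 3] dp=[0, 4]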
@@ -17,11 +17,12 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec
 
+    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group
 
     def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group)
+        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())
 
     def get_placement(self):
         return self.dist_spec.placement
@@ -30,7 +31,7 @@ class TensorSpec(object):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1
                 and self.dist_spec.num_partitions[0] == 1) \
-            or (self.dist_spec.process_group.size() == 1)
+            or (self.dist_spec.process_group.tp_world_size() == 1)
 
     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
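is_replicate now treats a tensor as replicated when the placement is REPLICATE, when every dim has a single partition, or when the tensor-parallel group has only one member. A pure-Python restatement of the predicate:

def is_replicate(placement: str, num_partitions, tp_world_size: int) -> bool:
    return (placement == "REPLICATE"
            or (len(num_partitions) == 1 and num_partitions[0] == 1)
            or tp_world_size == 1)

assert is_replicate("SHARD", [1], tp_world_size=4)     # single partition
assert is_replicate("SHARD", [2, 2], tp_world_size=1)  # no TP peers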