[refactor] remove gpc dependency in colotensor's _ops (#1189)

Jiarui Fang
2022-07-04 18:54:37 +08:00
committed by GitHub
parent abf6a262dc
commit 060b917daf
33 changed files with 499 additions and 357 deletions


@@ -5,6 +5,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 from packaging import version
+from colossalai.logging import get_dist_logger
 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
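The new import suggests the file now logs through ColossalAI's distributed logger rather than anything routed through gpc. A minimal usage sketch, assuming the standard get_dist_logger API; the actual call sites are not part of this hunk:

    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    # Emit from rank 0 only; the other ranks stay quiet.
    logger.info("distributed spec manager initialized", ranks=[0])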
@@ -64,7 +65,7 @@ class DistSpecManager:
         DistSpecManager._sanity_check(old_dist_spec, dist_spec)
         chunk = tensor
-        idx = dist_spec.process_group.rank()
+        idx = dist_spec.process_group.tp_local_rank()
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]
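The loop above decodes the tensor-parallel local rank into one chunk index per sharded dimension (a mixed-radix decomposition); the actual narrowing of the tensor happens in the rest of the method, which this hunk does not show. A standalone, hypothetical illustration of just that indexing:

    from math import prod

    def chunk_indices(local_rank: int, num_partitions: list) -> list:
        """Decode a flat TP-local rank into one chunk index per sharded dim.

        Example: with num_partitions=[2, 4] and local_rank=5 the result is
        [1, 1], i.e. the second of 2 chunks along the first sharded dim and
        the second of 4 along the next.
        """
        num_parts = prod(num_partitions)
        idx, out = local_rank, []
        for parts in num_partitions:
            num_parts //= parts
            out.append(idx // num_parts)   # which chunk this rank owns along the dim
            idx %= num_parts
        return out

    assert chunk_indices(5, [2, 4]) == [1, 1]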
@@ -91,8 +92,9 @@ class DistSpecManager:
         saved_dev = tensor.device
         tensor.data = tensor.data.cuda()
-        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
-        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
+        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
+        assert tensor.device.type == 'cuda'
+        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
             new_buffer = []
             dim = old_dist_spec.dims[i]
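After the change, the all-gather runs on the TP subgroup returned by tp_process_group() and the per-rank shards are stitched back together. A simplified sketch of that pattern for a tensor sharded along a single dimension; the helper name and the single-dimension simplification are assumptions, not part of the diff:

    import torch
    import torch.distributed as dist

    def gather_one_dim(tensor: torch.Tensor, pg, dim: int) -> torch.Tensor:
        """Gather a tensor sharded along `dim` across the TP group and concatenate.

        `pg` is assumed to expose tp_world_size() / tp_process_group() as in the diff.
        """
        assert tensor.device.type == 'cuda'  # NCCL collectives require CUDA tensors
        buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
        dist.all_gather(buffer, tensor, group=pg.tp_process_group())
        return torch.cat(buffer, dim=dim).contiguous()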
@@ -108,14 +110,14 @@ class DistSpecManager:
     @staticmethod
     def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        world_size = old_dist_spec.process_group.size()
+        world_size = old_dist_spec.process_group.tp_world_size()
         if world_size == 1:
             return tensor
-        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
             "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
             f"collective function, however, we got {tensor.device.type} device and " \
-            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+            f"{old_dist_spec.process_group.backend} backend"
         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
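The accessors used throughout this commit (tp_local_rank(), tp_world_size(), tp_process_group(), backend) imply that process_group is now ColossalAI's own ProcessGroup wrapper rather than a raw torch.distributed group looked up through gpc. A rough stand-in showing only the interface this diff relies on; the real class lives in the colossalai.tensor package and differs in detail:

    import torch.distributed as dist

    class ProcessGroup:
        """Illustrative stand-in for the tensor-parallel process group wrapper."""

        def __init__(self, tp_ranks, backend: str = "nccl"):
            self.backend = backend                  # queried instead of dist.get_backend(...)
            self._tp_ranks = list(tp_ranks)
            # Build the torch.distributed subgroup once and hand out its handle.
            self._tp_group = dist.new_group(ranks=self._tp_ranks, backend=backend)

        def tp_world_size(self) -> int:
            # Number of ranks participating in tensor parallelism.
            return len(self._tp_ranks)

        def tp_local_rank(self) -> int:
            # Position of this process inside the TP group, not its global rank.
            return self._tp_ranks.index(dist.get_rank())

        def tp_process_group(self):
            # Raw handle passed to collectives such as dist.all_gather / dist.all_to_all.
            return self._tp_group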
@@ -126,7 +128,7 @@ class DistSpecManager:
         scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
         gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
-        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
         output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
         assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
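For completeness, the collective in this last hunk implements the usual shard-axis swap: split the local shard along the new dimension, exchange pieces across the TP group, and concatenate the received pieces along the old one. A hedged, self-contained sketch assuming evenly divisible sizes and the same tp_* accessors; the helper name is mine, not the repository's:

    import torch
    import torch.distributed as dist

    def all_to_all_reshard(tensor: torch.Tensor, pg, gather_dim: int, scatter_dim: int) -> torch.Tensor:
        """Move the shard axis of `tensor` from `gather_dim` to `scatter_dim` in one collective.

        The local shard is split along `scatter_dim`, exchanged across the TP group,
        and the received pieces are concatenated along `gather_dim`. Assumes sizes
        divide evenly, so every exchanged piece has the same shape.
        """
        world_size = pg.tp_world_size()
        if world_size == 1:
            return tensor
        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty_like(scatter_list[0]) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=pg.tp_process_group())
        return torch.cat(gather_list, dim=gather_dim).contiguous()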