mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-29 05:26:21 +00:00
[refactor] remove gpc dependency in colotensor's _ops (#1189)
This commit is contained in:
@@ -5,6 +5,7 @@ from contextlib import contextmanager
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from packaging import version
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
|
||||
@@ -64,7 +65,7 @@ class DistSpecManager:
|
||||
DistSpecManager._sanity_check(old_dist_spec, dist_spec)
|
||||
|
||||
chunk = tensor
|
||||
idx = dist_spec.process_group.rank()
|
||||
idx = dist_spec.process_group.tp_local_rank()
|
||||
num_parts = prod(dist_spec.num_partitions)
|
||||
for i, dim in enumerate(dist_spec.dims):
|
||||
num_parts //= dist_spec.num_partitions[i]
|
||||
@@ -91,8 +92,9 @@ class DistSpecManager:
|
||||
saved_dev = tensor.device
|
||||
tensor.data = tensor.data.cuda()
|
||||
|
||||
buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
|
||||
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
|
||||
buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
|
||||
assert tensor.device.type == 'cuda'
|
||||
dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
|
||||
for i in range(len(old_dist_spec.dims) - 1, -1, -1):
|
||||
new_buffer = []
|
||||
dim = old_dist_spec.dims[i]
|
||||
@@ -108,14 +110,14 @@ class DistSpecManager:
|
||||
|
||||
@staticmethod
|
||||
def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
|
||||
world_size = old_dist_spec.process_group.size()
|
||||
world_size = old_dist_spec.process_group.tp_world_size()
|
||||
if world_size == 1:
|
||||
return tensor
|
||||
|
||||
assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
|
||||
assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
|
||||
"Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
|
||||
f"collective function, however, we got {tensor.device.type} device and " \
|
||||
f"{dist.get_backend(old_dist_spec.process_group)} backend"
|
||||
f"{old_dist_spec.process_group.backend} backend"
|
||||
|
||||
gather_dim = old_dist_spec.dims[0]
|
||||
scatter_dim = dist_spec.dims[0]
|
||||
@@ -126,7 +128,7 @@ class DistSpecManager:
|
||||
|
||||
scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
|
||||
gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
|
||||
dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
|
||||
dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())
|
||||
|
||||
output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
|
||||
assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
|
||||
|
Reference in New Issue
Block a user