[tensor]add 1D device mesh (#1492)

This commit is contained in:
YuliangLiu0306
2022-08-25 16:48:12 +08:00
committed by GitHub
parent b8d0e39eaf
commit 4b03c25f85
4 changed files with 66 additions and 13 deletions

View File

@@ -3,6 +3,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
from typing import Dict, List, Optional, Tuple, Union
import torch.distributed as dist
import math
from functools import reduce
@@ -29,9 +30,9 @@ class CommSpec:
Argument:
comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec.
sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action.
gather_dim(int, optional): The gather_dim of the tensor will be gathered.
shard_dim(int, optional): The shard_dim of the tensor will be sharded.
logical_process_axis(int, optional): The mesh_dim to implement the communication action.
gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
'''
def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
@@ -40,6 +41,11 @@ class CommSpec:
self.gather_dim = gather_dim
self.shard_dim = shard_dim
self.logical_process_axis = logical_process_axis
if isinstance(self.logical_process_axis, list):
self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
self.logical_process_axis = 0
else:
self.device_mesh = self.sharding_spec.device_mesh
def __repr__(self):
res_list = ["CommSpec:("]
@@ -70,11 +76,11 @@ class CommSpec:
'''
comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
return self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
return self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
return self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SHARD:
return 0
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
@@ -87,15 +93,14 @@ class CommSpec:
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
device_mesh = self.sharding_spec.device_mesh
process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
process_groups_list = self.device_mesh.process_groups_dict[self.logical_process_axis]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
for _ in range(self.device_mesh.mesh_shape[self.logical_process_axis])
]
tensor = tensor
group = process_group