[tensor] support runtime ShardingSpec apply (#1453)

* [tensor] support runtime ShardingSpec apply

* polish code

* polish code
This commit is contained in:
YuliangLiu0306
2022-08-19 13:39:51 +08:00
committed by GitHub
parent 177d3f5718
commit b73fb7a077
5 changed files with 485 additions and 11 deletions

View File

@@ -3,15 +3,18 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from enum import Enum
from copy import deepcopy
import torch.distributed as dist
import math
from functools import reduce
import operator
from torch.distributed import ReduceOp
class CollectiveCommPattern(Enum):
ALLGATHER = 'all_gather'
ALLTOALL = 'all_to_all'
SHARD = 'shard'
ALLREDUCE = 'all_reduce'
class CommSpec:
@@ -41,7 +44,7 @@ class CommSpec:
def __repr__(self):
res_list = ["CommSpec:("]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
res_list.append(f"comm_pattern:allgather, ")
res_list.append(f"comm_pattern:all_gather, ")
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
@@ -49,15 +52,19 @@ class CommSpec:
res_list.append(f"gather_dim:{self.gather_dim}, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis: {self.logical_process_axis})")
else:
elif self.comm_pattern == CollectiveCommPattern.SHARD:
res_list.append(f"comm_pattern:shard, ")
res_list.append(f"shard_dim:{self.shard_dim}, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
res_list.append(f"comm_pattern:all_reduce, ")
res_list.append(f"logical_process_axis:{self.logical_process_axis})")
return ''.join(res_list)
def get_comm_cost(self):
'''
For all_gather and all2all operation, the formula provided in DeviceMesh with alpha-beta model is used to
For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
compute the communication cost.
For shard operation, it is an on-chip operation, so the communication cost is zero.
'''
@@ -66,10 +73,77 @@ class CommSpec:
return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.ALLTOALL:
return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis)
return 0
if self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
return self.sharding_spec.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis)
if self.comm_pattern == CollectiveCommPattern.SHARD:
return 0
raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.")
def covert_spec_to_action(self):
pass
def covert_spec_to_action(self, tensor):
'''
Convert CommSpec into runtime action, implement real collection communication to target tensor.
The collection communication action is directed by the CommSpec.
Argument:
tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks.
'''
device_mesh = self.sharding_spec.device_mesh
process_groups_list = device_mesh.process_groups_dict[self.logical_process_axis]
if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor_list = [
torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
for _ in range(self.sharding_spec.device_mesh.mesh_shape[self.logical_process_axis])
]
tensor = tensor
group = process_group
dist.all_gather(tensor_list, tensor, group=group)
tensor.data = torch.cat(tuple(tensor_list), self.gather_dim)
elif self.comm_pattern == CollectiveCommPattern.SHARD:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
tensor = tensor
dim = self.shard_dim
length = tensor.shape[self.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank())
tensor.data = torch.narrow(tensor, dim, start, length)
elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
new_shape = list(tensor.shape)
new_shape[self.shard_dim] = new_shape[self.shard_dim] // len(rank_list)
new_shape = torch.Size(new_shape)
output_tensor_list = [
torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
]
dim = self.shard_dim
length = tensor.shape[self.shard_dim] // len(rank_list)
input_tensor_list = [
torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
]
group = process_group
dist.all_to_all(output_tensor_list, input_tensor_list, group)
tensor.data = torch.cat(tuple(output_tensor_list), self.gather_dim)
elif self.comm_pattern == CollectiveCommPattern.ALLREDUCE:
# For the consistency of collective communication operation, we temporally do not
# allow all_reduce two different mesh dimensions in the same time.
# e.g.: MatMul[(R, S01), (S01, R)] -> Partial(R, R),
# all_reduce(Partial, logical_pg=(0, 1)) is NOT allowed, instead
# we need to do this in two steps:
# 1. all_reduce(Partial, logical_pg=1)
# 2. all_reduce(Partial, logical_pg=0)
for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list:
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
tensor.data = tensor
else:
tensor.data = tensor
class ShapeConsistencyManager:
@@ -191,7 +265,7 @@ class ShapeConsistencyManager:
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
# skip (R, R) -> (R, S01) is NOT allowed
# skip (R, S01) -> (S01, R) is NOT allowed
if len(source_spec.dim_partition_dict[b_index]) >= 2:
continue
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
@@ -409,7 +483,7 @@ class ShapeConsistencyManager:
self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence)
return (transform_path, comm_action_sequence, total_cost)
temp_sharding_spec = deepcopy(source_spec)
temp_sharding_spec = source_spec
transform_path.append(temp_sharding_spec)
# To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms
while total_steps <= MAX_TRANSFORM_STEPS:
@@ -428,9 +502,9 @@ class ShapeConsistencyManager:
return (transform_path, comm_action_sequence, total_cost)
if spec_difference < best_difference_score:
temp_sharding_spec = deepcopy(sharding_spec)
temp_sharding_spec = sharding_spec
temp_cost = cost
temp_comm_spec = deepcopy(comm_spec)
temp_comm_spec = comm_spec
best_difference_score = spec_difference
transform_path.append(temp_sharding_spec)
@@ -439,3 +513,67 @@ class ShapeConsistencyManager:
total_steps += 1
raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.")
def apply(self, tensor_with_sharding_spec, target_spec):
'''
Apply target_spec to tensor with source sharding spec, the transform path is generated by the
shape_consistency method.
Argument:
tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec.
target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec.
Example:
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = torch.Size((4, 2))
shape_consistency_manager = ShapeConsistencyManager()
dim_partition_source = {0: [0]}
dim_partition_target = {1: [0]}
# DistSpec:
# shard_sequence: S0,R
# device_mesh_shape: (2, 2)
sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)
# DistSpec:
# shard_sequence: R,S0
# device_mesh_shape: (2, 2)
sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)
if rank in (0, 1):
sharded_tensor_0 = torch.zeros(2, 1)
sharded_tensor_1 = torch.ones(2, 1)
# tensor([[0., 1.],
# [0., 1.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
if rank in (2, 3):
sharded_tensor_0 = torch.ones(2, 1) * 2
sharded_tensor_1 = torch.ones(2, 1) * 3
# tensor([[2., 3.],
# [2., 3.]])
tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda()
tensor_to_comm.sharding_spec = sharding_spec_source
shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target)
print(tensor_to_comm)
Output in rank0 and rank2:
tensor([[0.],
[0.],
[2.],
[2.]])
Output in rank1 and rank3:
tensor([[1.],
[1.],
[3.],
[3.]])
'''
_, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec)
for comm_spec in comm_action_sequence:
comm_spec.covert_spec_to_action(tensor_with_sharding_spec)
tensor_with_sharding_spec.sharding_spec = target_spec