[autoparallel] mix gather (#1977)

* Add mix-gather

* Add comments

* Add comments

* Polish comments

* Change the global rank assumption

* Add tests

* Add two-step tests

* Fix 10 and 01

* Skip test becasue the number of GPUs
This commit is contained in:
Genghan Zhang
2022-11-23 21:49:17 +08:00
committed by GitHub
parent 7242bffc5f
commit d655eea515
5 changed files with 617 additions and 4 deletions

View File

@@ -7,7 +7,7 @@ import torch
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator
from .comm_spec import *
@@ -328,6 +328,59 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
pass
return valid_spec_dict
def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
'''
S0S1 -> RR
S1S0 -> RR
S01R -> RR
RS01 -> RR
'''
valid_spec_dict = {}
comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
tensor_dims = len(source_spec.entire_shape)
for f_index in range(tensor_dims - 1):
for b_index in range(f_index + 1, tensor_dims):
if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
continue
else:
if f_index in source_spec.dim_partition_dict:
# skip (S10, R) -> (R, R)
if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
continue
f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
else:
f_target_pair = (f_index, [])
if b_index in source_spec.dim_partition_dict:
# skip (R, S10) -> (R, R)
if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
continue
b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
else:
b_target_pair = (b_index, [])
gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
comm_spec = CommSpec(comm_pathern,
sharding_spec=source_spec,
gather_dim=gather_dim,
logical_process_axis=logical_process_axes,
forward_only=self.forward_only,
mix_gather=True)
cost_dict = comm_spec.get_comm_cost()
new_dim_partition_dict = {}
# generate new sharding spec
try:
new_sharding_spec = ShardingSpec(source_spec.device_mesh,
source_spec.entire_shape,
dim_partition_dict=new_dim_partition_dict)
for phase, cost in cost_dict.items():
cost_dict[phase] = cost + orig_cost_dict[phase]
valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
except ShardingSpecException:
pass
return valid_spec_dict
def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
'''
Get all valid sharding specs from source_spec with one step transform, and