mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-03 10:06:44 +00:00
[autoparallel] mix gather (#1977)
* Add mix-gather
* Add comments
* Add comments
* Polish comments
* Change the global rank assumption
* Add tests
* Add two-step tests
* Fix 10 and 01
* Skip test because the number of GPUs
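For context, the layout names used throughout this commit (S0S1, S1S0, S01R, RS01, RR) are ColossalAI sharding-spec shorthands: a ShardingSpec records, for each tensor dimension, the device-mesh axes it is sharded over in its dim_partition_dict. The snippet below is only an illustrative sketch of that encoding for a 2-D tensor on a 2-D device mesh; the dict literals are assumptions for the example, not taken from this diff.

# Illustrative sketch: layout names -> dim_partition_dict ({tensor_dim: [mesh axes]}).
# R means "replicated", i.e. the tensor dimension is absent from the dict.
layouts = {
    "S0S1": {0: [0], 1: [1]},    # dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1
    "S1S0": {0: [1], 1: [0]},    # same, with the mesh axes swapped
    "S01R": {0: [0, 1]},         # dim 0 sharded over both mesh axes, dim 1 replicated
    "RS01": {1: [0, 1]},         # dim 1 sharded over both mesh axes, dim 0 replicated
    "RR": {},                    # fully replicated: the target layout of a mix-gather
}

for name, dim_partition_dict in layouts.items():
    print(f"{name}: {dim_partition_dict}")

Mix-gather reaches RR from any of the first four layouts in a single gathering communication rather than a chain of two all-gathers, which is presumably what the "two-step tests" mentioned above compare it with.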
@@ -7,7 +7,7 @@ import torch

from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator

from .comm_spec import *

@@ -328,6 +328,59 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
                    pass
        return valid_spec_dict

    def get_all_mix_gather_spec(self, source_spec: ShardingSpec,
                                orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]:
        '''
        Get all valid sharding specs reachable from source_spec with a single mix-gather, i.e.:
        S0S1 -> RR
        S1S0 -> RR
        S01R -> RR
        RS01 -> RR
        '''
        valid_spec_dict = {}
        comm_pattern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD
        tensor_dims = len(source_spec.entire_shape)
        for f_index in range(tensor_dims - 1):
            for b_index in range(f_index + 1, tensor_dims):
                if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict):
                    # neither dimension is sharded, so there is nothing to gather
                    continue
                else:
                    if f_index in source_spec.dim_partition_dict:
                        f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
                        # skip (S10, R) -> (R, R)
                        if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]:
                            continue
                    else:
                        f_target_pair = (f_index, [])
                    if b_index in source_spec.dim_partition_dict:
                        b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
                        # skip (R, S10) -> (R, R)
                        if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]:
                            continue
                    else:
                        b_target_pair = (b_index, [])

                gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair)
                comm_spec = CommSpec(comm_pattern,
                                     sharding_spec=source_spec,
                                     gather_dim=gather_dim,
                                     logical_process_axis=logical_process_axes,
                                     forward_only=self.forward_only,
                                     mix_gather=True)
                cost_dict = comm_spec.get_comm_cost()
                # the target spec is fully replicated, so its dim_partition_dict is empty
                new_dim_partition_dict = {}
                # generate new sharding spec
                try:
                    new_sharding_spec = ShardingSpec(source_spec.device_mesh,
                                                     source_spec.entire_shape,
                                                     dim_partition_dict=new_dim_partition_dict)
                    for phase, cost in cost_dict.items():
                        cost_dict[phase] = cost + orig_cost_dict[phase]
                    valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict)
                except ShardingSpecException:
                    pass

        return valid_spec_dict

    def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]:
        '''
        Get all valid sharding specs from source_spec with one step transform, and
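To make the control flow of the new get_all_mix_gather_spec method easier to follow, here is a minimal standalone sketch of the candidate enumeration it performs before handing each pair to mix_gather_simulator. The helper name, the plain-dict inputs, and the yielded tuple format are assumptions for illustration only, not ColossalAI API.

# Standalone sketch (assumed helper, not library code): enumerate the
# (f_target_pair, b_target_pair) arguments that get_all_mix_gather_spec
# would pass to mix_gather_simulator for a given source layout.
from copy import deepcopy
from typing import Dict, List, Tuple


def enumerate_mix_gather_candidates(entire_shape: Tuple[int, ...],
                                    dim_partition_dict: Dict[int, List[int]]):
    tensor_dims = len(entire_shape)
    for f_index in range(tensor_dims - 1):
        for b_index in range(f_index + 1, tensor_dims):
            if f_index not in dim_partition_dict and b_index not in dim_partition_dict:
                continue    # neither dimension is sharded on this pair, nothing to gather
            f_axes = deepcopy(dim_partition_dict.get(f_index, []))
            b_axes = deepcopy(dim_partition_dict.get(b_index, []))
            # mirror the (S10, R) / (R, S10) skips above: multi-axis shards must be ascending
            if len(f_axes) == 2 and f_axes[0] >= f_axes[1]:
                continue
            if len(b_axes) == 2 and b_axes[0] >= b_axes[1]:
                continue
            yield (f_index, f_axes), (b_index, b_axes)


# e.g. an S0S1 layout of a 2-D tensor produces exactly one candidate pair:
print(list(enumerate_mix_gather_candidates((64, 64), {0: [0], 1: [1]})))
# -> [((0, [0]), (1, [1]))]

Each surviving pair then becomes one CommSpec with the MIXGATHER_FWD_SPLIT_BWD pattern; its communication cost is added to orig_cost_dict, and the resulting target spec is the fully replicated layout (an empty dim_partition_dict), as shown in the diff above.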