[autoparallel] add runtime pass and numerical test for view handler (#2018)

This commit is contained in:
YuliangLiu0306
2022-11-25 15:50:16 +08:00
committed by GitHub
parent bb6245612d
commit ea0f6b8df9
5 changed files with 251 additions and 50 deletions

View File

@@ -103,13 +103,18 @@ class ViewGenerator(FollowingStrategyGenerator):
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
# the total mesh dim list only has one element, so the shard dim has only one element as well.
shard_dim = list(dim_partition_dict_for_input.keys())[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
# it will gather the input through gather_dim during forward phase.
input_comm_action.comm_spec.gather_dim = shard_dim
# it will split the input activation grad through shard_dim during backward phase.
input_comm_action.comm_spec.shard_dim = shard_dim
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]

View File

@@ -105,6 +105,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
@@ -194,7 +195,7 @@ class LinearModuleHandler(ModuleHandler):
@operator_registry.register(F.linear)
class LinearFunctionHandler(NodeHandler):
"""
A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]: