mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] add runtime pass and numerical test for view handler (#2018)
@@ -37,6 +37,30 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
             str(node))
+
+    # experimental pass for torch.Tensor.view
+    # Arguments of view op will be divided in the sharded dimensions.
+    for node in nodes:
+        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
+            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
+            device_mesh = node.sharding_spec.device_mesh
+            new_args = []
+            for arg in node.args:
+                if isinstance(arg, Node):
+                    if isinstance(arg._meta_data, int):
+                        new_args.append(arg._meta_data)
+                    else:
+                        new_args.append(arg)
+                else:
+                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
+                    new_args.append(arg)
+
+            for dim, shard_dims in output_dim_partition_dict.items():
+                total_shard_size = 1
+                for shard_dim in shard_dims:
+                    total_shard_size *= device_mesh.shape[shard_dim]
+                new_args[dim + 1] //= total_shard_size
+            node.args = tuple(new_args)

     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
     # the dict to record comm actions of nodes
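In the inner loop above, every sharded output dimension divides the corresponding view() size argument by the product of the sizes of the mesh axes it is sharded over, so each device ends up calling view() with its local shape. A minimal standalone sketch of that arithmetic, assuming a made-up 2 x 4 mesh and sharding layout (the values below are illustrative, not taken from the commit):

    # toy stand-ins: a 2 x 4 device mesh and a view node whose output is sharded
    # on tensor dim 0 over mesh axis 0 and on tensor dim 2 over mesh axis 1
    mesh_shape = (2, 4)
    dim_partition_dict = {0: [0], 2: [1]}      # tensor dim -> list of mesh axes it is sharded over
    view_args = ['input_node', 16, 8, 32]      # args of x.view(16, 8, 32); args[0] is the tensor node

    new_args = list(view_args)
    for dim, shard_dims in dim_partition_dict.items():
        total_shard_size = 1
        for shard_dim in shard_dims:
            total_shard_size *= mesh_shape[shard_dim]
        # args[0] is the tensor itself, so the size for tensor dim `dim` sits at index dim + 1
        new_args[dim + 1] //= total_shard_size

    print(new_args)    # ['input_node', 8, 8, 8] -> the per-device view shape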
@@ -103,13 +103,18 @@ class ViewGenerator(FollowingStrategyGenerator):
             # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
             if len(total_mesh_dim_list) == 1:
                 total_mesh_dim_list = total_mesh_dim_list[0]
+                # the total mesh dim list only has one element, so the shard dim has only one element as well.
+                shard_dim = list(dim_partition_dict_for_input.keys())[0]
                 input_comm_action = self.get_communication_action(
                     sharding_spec=sharding_spec_mapping["input"],
                     communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     logical_process_axis=total_mesh_dim_list,
                     comm_type=CommType.BEFORE,
                     arg_index=0)
-                input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+                # it will gather the input through gather_dim during forward phase.
+                input_comm_action.comm_spec.gather_dim = shard_dim
+                # it will split the input activation grad through shard_dim during backward phase.
+                input_comm_action.comm_spec.shard_dim = shard_dim

             elif len(total_mesh_dim_list) >= 2:
                 source_spec = sharding_spec_mapping["input"]
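The fix separates two indices that the removed line conflated: logical_process_axis is an axis of the device mesh (which group of ranks takes part in the collective), while gather_dim/shard_dim is an axis of the tensor (which dimension is concatenated on the forward all-gather and re-split on the backward pass). A toy single-process illustration of the intended GATHER_FWD_SPLIT_BWD behaviour, with made-up shapes and without ColossalAI's CommSpec API:

    import torch

    # made-up example: a (2, 4) device mesh shards a (16, 32) tensor on tensor dim 0
    # across mesh axis 1, so each of the 4 ranks in a mesh column holds a (4, 32) chunk
    mesh_shape = (2, 4)
    logical_process_axis = 1    # mesh axis: which ranks communicate
    shard_dim = 0               # tensor dim: which dimension is concatenated / re-split

    local_chunk = torch.randn(16 // mesh_shape[logical_process_axis], 32)    # (4, 32)

    # forward of GATHER_FWD_SPLIT_BWD: all-gather concatenates chunks along the *tensor* shard dim
    gathered = torch.cat([local_chunk] * mesh_shape[logical_process_axis], dim=shard_dim)
    assert gathered.shape == (16, 32)

    # backward: the incoming activation grad is split back along the same tensor dim
    grad = torch.ones_like(gathered)
    grad_chunks = torch.chunk(grad, mesh_shape[logical_process_axis], dim=shard_dim)
    assert grad_chunks[0].shape == local_chunk.shape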
@@ -105,6 +105,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
                                      dim_mapping={0: i},
                                      physical_shape=output_op_data.data.shape,
                                      inplace=True)
+                strategy_copy.name = f'{strategy.name}_{i}'
                 sharding_strategies.append(strategy_copy)
             except ShardingNotDivisibleError as e:
                 logger.debug(
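The loop around this hunk (only partly visible here) tries to map the logical sharded output dimension onto each physical output dimension i, keeps the copies whose dimension divides evenly, and, with the added line, gives each surviving copy a distinct name. A schematic, ColossalAI-free sketch of that enumerate-and-filter pattern; convert_logical_to_physical, shard_size and the dict-based strategy are invented stand-ins for the real objects:

    class ShardingNotDivisibleError(Exception):
        """Raised when a physical dim cannot be evenly divided by the shard size."""

    def convert_logical_to_physical(strategy_name, physical_shape, shard_size):
        # hypothetical stand-in for the real conversion helper: try to place the single
        # logical sharded dim on every physical output dim i, keep only the divisible ones,
        # and suffix each surviving copy with the dim index (mirroring the added name line).
        candidates = []
        for i, dim_size in enumerate(physical_shape):
            try:
                if dim_size % shard_size != 0:
                    raise ShardingNotDivisibleError(f'dim {i} ({dim_size}) is not divisible by {shard_size}')
                strategy_copy = {'name': f'{strategy_name}_{i}', 'shard_dim': i}
                candidates.append(strategy_copy)
            except ShardingNotDivisibleError as e:
                # mirrors the logger.debug(...) branch in the real pass: skip, do not fail
                continue
        return candidates

    print(convert_logical_to_physical('S0R = S0R x RR', physical_shape=(6, 4, 8), shard_size=4))
    # keeps only dims 1 and 2, named 'S0R = S0R x RR_1' and 'S0R = S0R x RR_2'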
@@ -194,7 +195,7 @@ class LinearModuleHandler(ModuleHandler):
 @operator_registry.register(F.linear)
 class LinearFunctionHandler(NodeHandler):
     """
-    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
+    A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """

     def get_strategy_generator(self) -> List[StrategyGenerator]: