[autoparallel] add experimental permute handler (#2029)

This commit is contained in:
YuliangLiu0306
2022-11-27 20:26:52 +08:00
committed by GitHub
parent 95c4532fff
commit 81330b0352
11 changed files with 657 additions and 37 deletions

View File

@@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
# experimental pass for torch.Tensor.view
# Arguments of view op will be divided in the sharded dimensions.
for node in nodes:
if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
device_mesh = node.sharding_spec.device_mesh
new_args = []
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, int):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[dim + 1] //= total_shard_size
node.args = tuple(new_args)
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
# the dict to record comm actions of nodes
@@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
This pass will process node args to adapt the distributed tensor layout.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
for node in nodes:
# skip the placeholder node added in _solution_annotation pass
if not hasattr(node, 'sharding_spec'):
continue
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
device_mesh = node.sharding_spec.device_mesh
new_args = []
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
# process the node with (input, *shape) style args
if method in (torch.Tensor.view, torch.Tensor.reshape):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, int):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
# we will skip the dim with -1 value
if new_args[dim + 1] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[dim + 1] //= total_shard_size
node.args = tuple(new_args)
elif node.op == 'call_function':
target = node.target
# process the node with (input, torch.Size) style args
if target in (torch.reshape,):
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, (tuple, list)):
new_args.append(list(arg._meta_data))
else:
new_args.append(arg)
else:
assert isinstance(
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
new_args.append(list(arg))
for dim, shard_dims in output_dim_partition_dict.items():
# we will skip the dim with -1 value
if new_args[1][dim] == -1:
continue
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[1][dim] //= total_shard_size
node.args = tuple(new_args)
return gm
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
"""
Apply the sharding action to the module parameters and buffers following the
instructions of solver solution.
@@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
gm, solution)
gm = _node_args_converting(gm, device_mesh)
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
# gm = implicit_comm_action_apply(gm)
gm = _module_params_sharding(gm, device_mesh)