[autoparallel] add runtime pass and numerical test for view handler (#2018)

This commit is contained in:
YuliangLiu0306
2022-11-25 15:50:16 +08:00
committed by GitHub
parent bb6245612d
commit ea0f6b8df9
5 changed files with 251 additions and 50 deletions

View File

@@ -37,6 +37,30 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
str(node))
# experimental pass for torch.Tensor.view
# Arguments of view op will be divided in the sharded dimensions.
for node in nodes:
if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
output_dim_partition_dict = node.sharding_spec.dim_partition_dict
device_mesh = node.sharding_spec.device_mesh
new_args = []
for arg in node.args:
if isinstance(arg, Node):
if isinstance(arg._meta_data, int):
new_args.append(arg._meta_data)
else:
new_args.append(arg)
else:
assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
new_args.append(arg)
for dim, shard_dims in output_dim_partition_dict.items():
total_shard_size = 1
for shard_dim in shard_dims:
total_shard_size *= device_mesh.shape[shard_dim]
new_args[dim + 1] //= total_shard_size
node.args = tuple(new_args)
# the dict to get input sharding specs of user node
sharding_spec_convert_dict = {}
# the dict to record comm actions of nodes