Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] add experimental permute handler (#2029)
@@ -37,30 +37,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
             origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(
                 str(node))
-
-    # experimental pass for torch.Tensor.view
-    # Arguments of view op will be divided in the sharded dimensions.
-    for node in nodes:
-        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
-            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
-            device_mesh = node.sharding_spec.device_mesh
-            new_args = []
-            for arg in node.args:
-                if isinstance(arg, Node):
-                    if isinstance(arg._meta_data, int):
-                        new_args.append(arg._meta_data)
-                    else:
-                        new_args.append(arg)
-                else:
-                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
-                    new_args.append(arg)
-
-            for dim, shard_dims in output_dim_partition_dict.items():
-                total_shard_size = 1
-                for shard_dim in shard_dims:
-                    total_shard_size *= device_mesh.shape[shard_dim]
-                new_args[dim + 1] //= total_shard_size
-            node.args = tuple(new_args)
 
     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
     # the dict to record comm actions of nodes
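Note on the removed block: it rewrote the integer shape arguments of a Tensor.view call so that each rank reshapes only its local shard, dividing every sharded output dimension by the total number of shards along that dimension. A minimal standalone sketch of that arithmetic in plain PyTorch; dim_partition_dict and mesh_shape below are hypothetical stand-ins for the node's sharding spec and device mesh:

import torch

# hypothetical sharding metadata: output dim 0 is sharded over mesh axes [0, 1]
dim_partition_dict = {0: [0, 1]}
mesh_shape = (2, 2)                                # a 2 x 2 device mesh -> 4 shards along dim 0

global_view_shape = [64, 1024]                     # shape the original graph passes to Tensor.view
local_view_shape = list(global_view_shape)
for dim, shard_dims in dim_partition_dict.items():
    total_shard_size = 1
    for shard_dim in shard_dims:
        total_shard_size *= mesh_shape[shard_dim]  # 2 * 2 = 4
    local_view_shape[dim] //= total_shard_size     # 64 -> 16

# each rank holds a (16, 16, 64) shard of the global (64, 16, 64) tensor
local_shard = torch.randn(16, 16, 64)
local_out = local_shard.view(*local_view_shape)    # view(16, 1024) succeeds on the local shard
print(local_view_shape, local_out.shape)           # [16, 1024] torch.Size([16, 1024])

The same division reappears, generalized to reshape and to call_function nodes, in the _node_args_converting pass added below.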
@@ -113,7 +89,74 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
     return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
 
 
-def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
+def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
+    """
+    This pass will process node args to adapt the distributed tensor layout.
+    """
+    mod_graph = gm.graph
+    nodes = tuple(mod_graph.nodes)
+
+    for node in nodes:
+        # skip the placeholder node added in _solution_annotation pass
+        if not hasattr(node, 'sharding_spec'):
+            continue
+        output_dim_partition_dict = node.sharding_spec.dim_partition_dict
+        device_mesh = node.sharding_spec.device_mesh
+        new_args = []
+
+        if node.op == 'call_method':
+            method = getattr(node.args[0]._meta_data.__class__, node.target)
+            # process the node with (input, *shape) style args
+            if method in (torch.Tensor.view, torch.Tensor.reshape):
+                for arg in node.args:
+                    if isinstance(arg, Node):
+                        if isinstance(arg._meta_data, int):
+                            new_args.append(arg._meta_data)
+                        else:
+                            new_args.append(arg)
+                    else:
+                        assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
+                        new_args.append(arg)
+
+                for dim, shard_dims in output_dim_partition_dict.items():
+                    # we will skip the dim with -1 value
+                    if new_args[dim + 1] == -1:
+                        continue
+                    total_shard_size = 1
+                    for shard_dim in shard_dims:
+                        total_shard_size *= device_mesh.shape[shard_dim]
+                    new_args[dim + 1] //= total_shard_size
+                node.args = tuple(new_args)
+
+        elif node.op == 'call_function':
+            target = node.target
+            # process the node with (input, torch.Size) style args
+            if target in (torch.reshape,):
+                for arg in node.args:
+                    if isinstance(arg, Node):
+                        if isinstance(arg._meta_data, (tuple, list)):
+                            new_args.append(list(arg._meta_data))
+                        else:
+                            new_args.append(arg)
+                    else:
+                        assert isinstance(
+                            arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
+                        new_args.append(list(arg))
+
+                for dim, shard_dims in output_dim_partition_dict.items():
+                    # we will skip the dim with -1 value
+                    if new_args[1][dim] == -1:
+                        continue
+                    total_shard_size = 1
+                    for shard_dim in shard_dims:
+                        total_shard_size *= device_mesh.shape[shard_dim]
+                    new_args[1][dim] //= total_shard_size
+                node.args = tuple(new_args)
+
+    return gm
+
+
+def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     """
     Apply the sharding action to the module parameters and buffers following the
     instructions of solver solution.
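The new _node_args_converting pass generalizes the block removed in the first hunk: it covers both the call_method form (x.view(*shape) and x.reshape(*shape), where the size arguments follow the input, hence the dim + 1 offset) and the call_function form (torch.reshape(x, shape), where the sizes arrive as one tuple in args[1]), and it skips dimensions given as -1. A hedged sketch of its effect on a traced graph; the pass itself reads node.sharding_spec, so the dim_partition_dict and mesh_shape here are hypothetical stand-ins for metadata that ColossalAI's solution-annotation pass would attach:

import torch
import torch.fx

class Toy(torch.nn.Module):

    def forward(self, x):
        return x.view(64, 1024)

gm = torch.fx.symbolic_trace(Toy())

# hypothetical stand-ins for node.sharding_spec.dim_partition_dict and device_mesh.shape
dim_partition_dict = {0: [0]}
mesh_shape = (4, 2)                                   # output dim 0 sharded over mesh axis 0 -> 4-way split

for node in gm.graph.nodes:
    if node.op == 'call_method' and node.target == 'view':
        new_args = list(node.args)                    # (input_node, 64, 1024)
        for dim, shard_dims in dim_partition_dict.items():
            total_shard_size = 1
            for shard_dim in shard_dims:
                total_shard_size *= mesh_shape[shard_dim]
            new_args[dim + 1] //= total_shard_size    # 64 -> 16; the +1 skips the input argument
        node.args = tuple(new_args)

gm.recompile()
out = gm(torch.randn(16, 16, 64))                     # feed a local shard of the global (64, 16, 64) tensor
print(out.shape)                                      # torch.Size([16, 1024])

Unlike this sketch, the real pass also resolves Node arguments whose _meta_data is a concrete int and rewrites torch.reshape's shape tuple in place via new_args[1][dim].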
@@ -216,6 +259,7 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
 def runtime_preparation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh: DeviceMesh):
     gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
         gm, solution)
+    gm = _node_args_converting(gm, device_mesh)
     # TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
     # gm = implicit_comm_action_apply(gm)
     gm = _module_params_sharding(gm, device_mesh)
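For orientation, runtime_preparation_pass now runs _solution_annotatation, then the newly added _node_args_converting, then _module_params_sharding, with implicit_comm_action_apply still disabled pending its implementation. A hedged usage sketch; the import path and the returned tuple are not shown in this diff and are assumptions based on the surrounding file, and gm, solution, and device_mesh are assumed to come from ColossalAI's tracer, strategy solver, and device mesh setup:

# Assumed context: gm is a traced torch.fx.GraphModule, solution is the solver's per-node
# strategy index list, and device_mesh is a colossalai DeviceMesh (all produced upstream).
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass  # assumed path

gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = runtime_preparation_pass(
    gm, solution, device_mesh)
# After the pass, view/reshape nodes in gm take per-rank local shapes and the module's
# parameters and buffers are sharded across the device mesh.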