mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-09 08:58:12 +00:00
[autoparallel] refactor runtime pass (#2644)
* [autoparallel] refactor runtime pass * add unit test * polish
This commit is contained in:
parent
89f8975fb8
commit
cb2c6a2415
@ -6,3 +6,8 @@ OUTPUT_SAVED_MOD = [
|
|||||||
torch.nn.ReLU,
|
torch.nn.ReLU,
|
||||||
torch.nn.Softmax,
|
torch.nn.Softmax,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# SHAPE_ARGUMENT_OPS contains node with (input, *shape) style args.
|
||||||
|
# This list could be extended if any other method has the same
|
||||||
|
# argument style as view and reshape.
|
||||||
|
SHAPE_ARGUMENT_OPS = [torch.Tensor.view, torch.Tensor.reshape, torch.reshape]
|
||||||
|
@ -19,6 +19,8 @@ from colossalai.tensor.comm_spec import _all_reduce
|
|||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
|
from .constants import SHAPE_ARGUMENT_OPS
|
||||||
|
|
||||||
shape_consistency_manager = ShapeConsistencyManager()
|
shape_consistency_manager = ShapeConsistencyManager()
|
||||||
|
|
||||||
|
|
||||||
@ -51,23 +53,16 @@ def size_processing(size: Union[int, torch.Size],
|
|||||||
return size
|
return size
|
||||||
|
|
||||||
|
|
||||||
def _solution_annotatation(gm: torch.fx.GraphModule,
|
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
|
||||||
solution: List[int],
|
strategies_constructor: StrategiesConstructor):
|
||||||
strategies_constructor: StrategiesConstructor = None):
|
|
||||||
"""
|
"""
|
||||||
This method is used to stick the solution strategy to the nodes and add the information
|
This method is used to stick the solution strategy to the nodes and add the information
|
||||||
required in runtime into graph as placeholder nodes.
|
required in runtime into graph as placeholder nodes.
|
||||||
"""
|
"""
|
||||||
mod_graph = gm.graph
|
mod_graph = gm.graph
|
||||||
# TODO: In future PR, strategies_constructor should be a required argument,
|
|
||||||
# instead of optional argument. This is because we don't need to consider nodes with
|
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
||||||
# no strategy in runtime preparation pass.
|
no_strategy_nodes = strategies_constructor.no_strategy_nodes
|
||||||
if strategies_constructor is not None:
|
|
||||||
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
|
|
||||||
no_strategy_nodes = strategies_constructor.no_strategy_nodes
|
|
||||||
else:
|
|
||||||
nodes = tuple(mod_graph.nodes)
|
|
||||||
no_strategy_nodes = []
|
|
||||||
|
|
||||||
# the dict to get origin sharding spec of node
|
# the dict to get origin sharding spec of node
|
||||||
origin_node_sharding_spec_dict = {}
|
origin_node_sharding_spec_dict = {}
|
||||||
@ -97,6 +92,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
|
|||||||
target_sharding_specs.append(target_sharding_spec)
|
target_sharding_specs.append(target_sharding_spec)
|
||||||
sharding_spec_convert_dict[index] = target_sharding_specs
|
sharding_spec_convert_dict[index] = target_sharding_specs
|
||||||
setattr(node, 'target_sharding_specs', target_sharding_specs)
|
setattr(node, 'target_sharding_specs', target_sharding_specs)
|
||||||
|
|
||||||
# the get_attr node strategy is kind of pending strategy, which means we will change it
|
# the get_attr node strategy is kind of pending strategy, which means we will change it
|
||||||
# to the same strategy of the user node.
|
# to the same strategy of the user node.
|
||||||
if node.op == 'get_attr':
|
if node.op == 'get_attr':
|
||||||
@ -134,7 +130,7 @@ def _solution_annotatation(gm: torch.fx.GraphModule,
|
|||||||
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
||||||
|
|
||||||
|
|
||||||
def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||||
"""
|
"""
|
||||||
In the auto parallel system, tensors may get shard on different devices, so the size of tensors
|
In the auto parallel system, tensors may get shard on different devices, so the size of tensors
|
||||||
need to be converted to the size of original tensor and managed by the users, such as torch.view,
|
need to be converted to the size of original tensor and managed by the users, such as torch.view,
|
||||||
@ -145,6 +141,80 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||||||
nodes = tuple(mod_graph.nodes)
|
nodes = tuple(mod_graph.nodes)
|
||||||
node_pairs = {}
|
node_pairs = {}
|
||||||
|
|
||||||
|
# DeviceMesh information instructs the scaling of the size value
|
||||||
|
device_mesh_info = {}
|
||||||
|
for dim, dim_size in enumerate(device_mesh.mesh_shape):
|
||||||
|
device_mesh_info[dim] = dim_size
|
||||||
|
|
||||||
|
def _extract_target_dim(node):
|
||||||
|
'''
|
||||||
|
A helper function to etract the target dimension from size node.
|
||||||
|
There are two usages of torch.Tensor.size:
|
||||||
|
1. tensor.size()
|
||||||
|
2. tensor.size(dim)
|
||||||
|
|
||||||
|
If a target_dim is assigned, then the output will be in type of int, instead of torch.Size.
|
||||||
|
Otherwise, the output will be in type of torch.Size and this function will return None.
|
||||||
|
'''
|
||||||
|
target_dim = None
|
||||||
|
if len(node.args) > 1:
|
||||||
|
target_dim = node.args[1]
|
||||||
|
if target_dim < 0:
|
||||||
|
target_dim += node.args[0]._meta_data.dim()
|
||||||
|
return target_dim
|
||||||
|
|
||||||
|
def _post_processing(node, size_processing_node):
|
||||||
|
'''
|
||||||
|
This function is used to process the dependency between the size node and its users after
|
||||||
|
inserting the size_process_node.
|
||||||
|
'''
|
||||||
|
# store original node and processing node pair in node_pairs dictioanry
|
||||||
|
# It will be used to replace the original node with processing node in slice object
|
||||||
|
node_pairs[node] = size_processing_node
|
||||||
|
size_processing_node._meta_data = node._meta_data
|
||||||
|
if 'activation_checkpoint' in node.meta:
|
||||||
|
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
||||||
|
|
||||||
|
user_list = list(node.users.keys())
|
||||||
|
for user in user_list:
|
||||||
|
if user == size_processing_node:
|
||||||
|
continue
|
||||||
|
new_args = list(user.args)
|
||||||
|
new_kwargs = dict(user.kwargs)
|
||||||
|
# the origin node may be a positional argument or key word argument of user node
|
||||||
|
if node in new_args:
|
||||||
|
# substitute the origin node with size_processing_node
|
||||||
|
new_args[new_args.index(node)] = size_processing_node
|
||||||
|
user.args = tuple(new_args)
|
||||||
|
elif str(node) in new_kwargs:
|
||||||
|
# substitute the origin node with size_processing_node
|
||||||
|
new_kwargs[str(node)] = size_processing_node
|
||||||
|
user.kwargs = new_kwargs
|
||||||
|
|
||||||
|
def _update_slice_object_args(slice_object):
|
||||||
|
'''
|
||||||
|
This function is used to update the slice object argument list.
|
||||||
|
If the slice object contains the Node argument, then the size node will be replaced with
|
||||||
|
'''
|
||||||
|
if isinstance(slice_object, slice):
|
||||||
|
start = slice_object.start
|
||||||
|
stop = slice_object.stop
|
||||||
|
step = slice_object.step
|
||||||
|
if start in node_pairs:
|
||||||
|
start = node_pairs[start]
|
||||||
|
if stop in node_pairs:
|
||||||
|
stop = node_pairs[stop]
|
||||||
|
if step in node_pairs:
|
||||||
|
step = node_pairs[step]
|
||||||
|
return slice(start, stop, step)
|
||||||
|
elif isinstance(slice_object, int):
|
||||||
|
if slice_object in node_pairs:
|
||||||
|
return node_pairs[slice_object]
|
||||||
|
else:
|
||||||
|
return slice_object
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unsupported slice object type: {type(slice_object)}")
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
|
|
||||||
if node.op == 'call_method' and node.target == 'size':
|
if node.op == 'call_method' and node.target == 'size':
|
||||||
@ -154,49 +224,15 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||||||
sharding_spec = node.args[0].sharding_spec
|
sharding_spec = node.args[0].sharding_spec
|
||||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||||
|
|
||||||
# there are two usages of torch.Tensor.size:
|
target_dim = _extract_target_dim(node)
|
||||||
# tensor.size()
|
|
||||||
# tensor.size(dim)
|
|
||||||
# if a target_dim is assigned, then the output will be
|
|
||||||
# in type of int, instead of torch.Size
|
|
||||||
target_dim = None
|
|
||||||
if len(node.args) > 1:
|
|
||||||
target_dim = node.args[1]
|
|
||||||
if target_dim < 0:
|
|
||||||
target_dim += node.args[0]._meta_data.dim()
|
|
||||||
|
|
||||||
# DeviceMesh information instructs the scaling of the size value
|
|
||||||
device_mesh_info = {}
|
|
||||||
for dim, dim_size in enumerate(device_mesh.mesh_shape):
|
|
||||||
device_mesh_info[dim] = dim_size
|
|
||||||
|
|
||||||
|
# insert size_processing node
|
||||||
with mod_graph.inserting_after(node):
|
with mod_graph.inserting_after(node):
|
||||||
size_processing_node = mod_graph.create_node('call_function',
|
size_processing_node = mod_graph.create_node('call_function',
|
||||||
size_processing,
|
size_processing,
|
||||||
args=(node, dim_partition_dict, device_mesh_info,
|
args=(node, dim_partition_dict, device_mesh_info,
|
||||||
target_dim, node.name))
|
target_dim, node.name))
|
||||||
# store original node and processing node pair in node_pairs dictioanry
|
_post_processing(node, size_processing_node)
|
||||||
# It will be used to replace the original node with processing node in slice object
|
|
||||||
node_pairs[node] = size_processing_node
|
|
||||||
size_processing_node._meta_data = node._meta_data
|
|
||||||
if 'activation_checkpoint' in node.meta:
|
|
||||||
size_processing_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
|
|
||||||
|
|
||||||
user_list = list(node.users.keys())
|
|
||||||
for user in user_list:
|
|
||||||
if user == size_processing_node:
|
|
||||||
continue
|
|
||||||
new_args = list(user.args)
|
|
||||||
new_kwargs = dict(user.kwargs)
|
|
||||||
# the origin node may be a positional argument or key word argument of user node
|
|
||||||
if node in new_args:
|
|
||||||
# substitute the origin node with size_processing_node
|
|
||||||
new_args[new_args.index(node)] = size_processing_node
|
|
||||||
user.args = tuple(new_args)
|
|
||||||
elif str(node) in new_kwargs:
|
|
||||||
# substitute the origin node with size_processing_node
|
|
||||||
new_kwargs[str(node)] = size_processing_node
|
|
||||||
user.kwargs = new_kwargs
|
|
||||||
|
|
||||||
if node.op == 'call_function' and node.target == operator.getitem:
|
if node.op == 'call_function' and node.target == operator.getitem:
|
||||||
|
|
||||||
@ -217,14 +253,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||||||
# In this pass, we need process the last two cases because
|
# In this pass, we need process the last two cases because
|
||||||
# node arguments may potentially appear in these cases.
|
# node arguments may potentially appear in these cases.
|
||||||
if isinstance(getitem_index, slice):
|
if isinstance(getitem_index, slice):
|
||||||
new_start, new_stop, new_step = getitem_index.start, getitem_index.stop, getitem_index.step
|
new_slice_item = _update_slice_object_args(getitem_index)
|
||||||
if getitem_index.start in node_pairs:
|
|
||||||
new_start = node_pairs[getitem_index.start]
|
|
||||||
elif getitem_index.stop in node_pairs:
|
|
||||||
new_stop = node_pairs[getitem_index.stop]
|
|
||||||
elif getitem_index.step in node_pairs:
|
|
||||||
new_step = node_pairs[getitem_index.step]
|
|
||||||
new_slice_item = slice(new_start, new_stop, new_step)
|
|
||||||
new_args = (node.args[0], new_slice_item)
|
new_args = (node.args[0], new_slice_item)
|
||||||
node.args = new_args
|
node.args = new_args
|
||||||
|
|
||||||
@ -237,16 +266,7 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||||||
if slice_item is None:
|
if slice_item is None:
|
||||||
new_slice_items.append(None)
|
new_slice_items.append(None)
|
||||||
continue
|
continue
|
||||||
|
new_slice_item = _update_slice_object_args(slice_item)
|
||||||
new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
|
|
||||||
|
|
||||||
if slice_item.start in node_pairs:
|
|
||||||
new_start = node_pairs[slice_item.start]
|
|
||||||
elif slice_item.stop in node_pairs:
|
|
||||||
new_stop = node_pairs[slice_item.stop]
|
|
||||||
elif slice_item.step in node_pairs:
|
|
||||||
new_step = node_pairs[slice_item.step]
|
|
||||||
new_slice_item = slice(new_start, new_stop, new_step)
|
|
||||||
new_slice_items.append(new_slice_item)
|
new_slice_items.append(new_slice_item)
|
||||||
|
|
||||||
new_args = (node.args[0], tuple(new_slice_items))
|
new_args = (node.args[0], tuple(new_slice_items))
|
||||||
@ -255,104 +275,109 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
|||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
def _node_args_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
def node_args_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
|
||||||
"""
|
"""
|
||||||
This pass will process node args to adapt the distributed tensor layout.
|
This pass will process node args to adapt the distributed tensor layout.
|
||||||
"""
|
"""
|
||||||
mod_graph = gm.graph
|
mod_graph = gm.graph
|
||||||
nodes = tuple(mod_graph.nodes)
|
nodes = tuple(mod_graph.nodes)
|
||||||
|
|
||||||
|
def _extract_info_from_sharding_spec(sharding_spec):
|
||||||
|
'''
|
||||||
|
This function is used to extract the dim_partition_dict and device_mesh from
|
||||||
|
sharding spec instance or a list of sharding spec.
|
||||||
|
'''
|
||||||
|
if isinstance(sharding_spec, ShardingSpec):
|
||||||
|
dim_partition_dict = sharding_spec.dim_partition_dict
|
||||||
|
device_mesh = sharding_spec.device_mesh
|
||||||
|
return dim_partition_dict, device_mesh
|
||||||
|
if sharding_spec is None:
|
||||||
|
return None, None
|
||||||
|
assert isinstance(sharding_spec,
|
||||||
|
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
|
||||||
|
|
||||||
|
device_mesh = sharding_spec[0].device_mesh
|
||||||
|
dim_partition_dict = []
|
||||||
|
for element in sharding_spec:
|
||||||
|
dim_partition_dict.append(_extract_info_from_sharding_spec(element))
|
||||||
|
return dim_partition_dict, sharding_spec
|
||||||
|
|
||||||
|
def _process_node_arguments(node):
|
||||||
|
new_args = []
|
||||||
|
for arg in node.args:
|
||||||
|
# There are two args style:
|
||||||
|
# 1. (input, *shape)
|
||||||
|
# 2. (input, shape)
|
||||||
|
# We will extract the elements from shape and add them into the new_args
|
||||||
|
# Finally, the args style of new_args will be unified to (input, *shape)
|
||||||
|
if isinstance(arg, Node):
|
||||||
|
if isinstance(arg._meta_data, (tuple, list)):
|
||||||
|
new_args.extend(arg._meta_data)
|
||||||
|
elif isinstance(arg._meta_data, int):
|
||||||
|
new_args.append(arg._meta_data)
|
||||||
|
else:
|
||||||
|
new_args.append(arg)
|
||||||
|
else:
|
||||||
|
assert isinstance(arg,
|
||||||
|
(int, tuple, list)), 'The argument in view node should be either type of Node or int.'
|
||||||
|
if isinstance(arg, (tuple, list)):
|
||||||
|
new_args.extend(arg)
|
||||||
|
else:
|
||||||
|
new_args.append(arg)
|
||||||
|
return new_args
|
||||||
|
|
||||||
|
def _scale_args_adapt_sharding_spec(dim_partition_dict, device_mesh, node):
|
||||||
|
new_args = _process_node_arguments(node)
|
||||||
|
if node.op == 'call_method':
|
||||||
|
args_to_process = list(new_args[1:])
|
||||||
|
else:
|
||||||
|
args_to_process = list(new_args)
|
||||||
|
for dim, shard_dims in dim_partition_dict.items():
|
||||||
|
total_shard_size = 1
|
||||||
|
for shard_dim in shard_dims:
|
||||||
|
total_shard_size *= device_mesh.shape[shard_dim]
|
||||||
|
|
||||||
|
# we will skip the dim with -1 value
|
||||||
|
if args_to_process[dim] == -1:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# TODO: add assertion here to make sure the dim size is divisible by total_shard_size
|
||||||
|
args_to_process[dim] //= total_shard_size
|
||||||
|
|
||||||
|
args_to_process = tuple(args_to_process)
|
||||||
|
|
||||||
|
if node.op == 'call_method':
|
||||||
|
new_args = (new_args[0],) + args_to_process
|
||||||
|
else:
|
||||||
|
new_args = args_to_process
|
||||||
|
|
||||||
|
node.args = new_args
|
||||||
|
|
||||||
|
def _filter_node_with_shape_args(node):
|
||||||
|
if node.op == 'call_method':
|
||||||
|
target = getattr(node.args[0]._meta_data.__class__, node.target)
|
||||||
|
elif node.op == 'call_function':
|
||||||
|
target = node.target
|
||||||
|
else:
|
||||||
|
target = None
|
||||||
|
|
||||||
|
if target in SHAPE_ARGUMENT_OPS:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
# skip the placeholder node added in _solution_annotation pass
|
# skip the placeholder node added in _solution_annotation pass
|
||||||
if not hasattr(node, 'sharding_spec'):
|
if not hasattr(node, 'sharding_spec'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
def _process_sharding_spec(sharding_spec):
|
output_dim_partition_dict, device_mesh = _extract_info_from_sharding_spec(node.sharding_spec)
|
||||||
if isinstance(sharding_spec, ShardingSpec):
|
if _filter_node_with_shape_args(node):
|
||||||
dim_partition_dict = sharding_spec.dim_partition_dict
|
_scale_args_adapt_sharding_spec(output_dim_partition_dict, device_mesh, node)
|
||||||
device_mesh = sharding_spec.device_mesh
|
|
||||||
return dim_partition_dict, device_mesh
|
|
||||||
if sharding_spec is None:
|
|
||||||
return None, None
|
|
||||||
assert isinstance(sharding_spec,
|
|
||||||
(tuple, list)), 'sharding_spec should be type of ShardingSpec, tuple, list or None'
|
|
||||||
|
|
||||||
device_mesh = sharding_spec[0].device_mesh
|
|
||||||
dim_partition_dict = []
|
|
||||||
for element in sharding_spec:
|
|
||||||
dim_partition_dict.append(_process_sharding_spec(element))
|
|
||||||
return dim_partition_dict, sharding_spec
|
|
||||||
|
|
||||||
output_dim_partition_dict, device_mesh = _process_sharding_spec(node.sharding_spec)
|
|
||||||
new_args = []
|
|
||||||
|
|
||||||
if node.op == 'call_method':
|
|
||||||
method = getattr(node.args[0]._meta_data.__class__, node.target)
|
|
||||||
# process the node with (input, *shape) style args
|
|
||||||
if method in (torch.Tensor.view, torch.Tensor.reshape):
|
|
||||||
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, Node):
|
|
||||||
if isinstance(arg._meta_data, (int, tuple, list)):
|
|
||||||
new_args.append(arg._meta_data)
|
|
||||||
else:
|
|
||||||
new_args.append(arg)
|
|
||||||
else:
|
|
||||||
assert isinstance(
|
|
||||||
arg, (int, tuple, list)), 'The argument in view node should be either type of Node or int.'
|
|
||||||
new_args.append(arg)
|
|
||||||
|
|
||||||
for dim, shard_dims in output_dim_partition_dict.items():
|
|
||||||
total_shard_size = 1
|
|
||||||
for shard_dim in shard_dims:
|
|
||||||
total_shard_size *= device_mesh.shape[shard_dim]
|
|
||||||
# There are two ways to use torch.view:
|
|
||||||
# 1. torch.view(input, *shape)
|
|
||||||
# 2. torch.view(input, shape)
|
|
||||||
if isinstance(new_args[1], int):
|
|
||||||
# we will skip the dim with -1 value
|
|
||||||
if new_args[dim + 1] == -1:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
new_args[dim + 1] //= total_shard_size
|
|
||||||
else:
|
|
||||||
new_args[1] = list(new_args[1])
|
|
||||||
# we will skip the dim with -1 value
|
|
||||||
if new_args[1][dim] == -1:
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
new_args[1][dim] //= total_shard_size
|
|
||||||
node.args = tuple(new_args)
|
|
||||||
|
|
||||||
elif node.op == 'call_function':
|
|
||||||
target = node.target
|
|
||||||
# process the node with (input, torch.Size) style args
|
|
||||||
if target in (torch.reshape,):
|
|
||||||
for arg in node.args:
|
|
||||||
if isinstance(arg, Node):
|
|
||||||
if isinstance(arg._meta_data, (tuple, list)):
|
|
||||||
new_args.append(list(arg._meta_data))
|
|
||||||
else:
|
|
||||||
new_args.append(arg)
|
|
||||||
else:
|
|
||||||
assert isinstance(
|
|
||||||
arg, (tuple, list)), 'The argument in reshape node should be either type of Node or tuple.'
|
|
||||||
new_args.append(list(arg))
|
|
||||||
|
|
||||||
for dim, shard_dims in output_dim_partition_dict.items():
|
|
||||||
# we will skip the dim with -1 value
|
|
||||||
if new_args[1][dim] == -1:
|
|
||||||
continue
|
|
||||||
total_shard_size = 1
|
|
||||||
for shard_dim in shard_dims:
|
|
||||||
total_shard_size *= device_mesh.shape[shard_dim]
|
|
||||||
new_args[1][dim] //= total_shard_size
|
|
||||||
node.args = tuple(new_args)
|
|
||||||
|
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
|
def module_params_sharding_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, overlap=False):
|
||||||
"""
|
"""
|
||||||
Apply the sharding action to the module parameters and buffers following the
|
Apply the sharding action to the module parameters and buffers following the
|
||||||
instructions of solver solution.
|
instructions of solver solution.
|
||||||
@ -361,6 +386,49 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
|
|||||||
nodes = tuple(mod_graph.nodes)
|
nodes = tuple(mod_graph.nodes)
|
||||||
# This stream is created for overlaping the communication and computation.
|
# This stream is created for overlaping the communication and computation.
|
||||||
reduction_stream = torch.cuda.Stream()
|
reduction_stream = torch.cuda.Stream()
|
||||||
|
|
||||||
|
def _add_hook_for_grad_communication(node, param):
|
||||||
|
|
||||||
|
comm_actions = node.best_strategy.communication_actions
|
||||||
|
|
||||||
|
def _filter_param_to_hook(node, op_data, comm_action):
|
||||||
|
if node.op == 'call_module' and op_data.type == OperationDataType.PARAM and op_data.name == param.name and comm_action.comm_type == CommType.HOOK:
|
||||||
|
return True
|
||||||
|
if node.op == 'get_attr' and isinstance(
|
||||||
|
node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
for operation_data, comm_action in comm_actions.items():
|
||||||
|
comm_spec_to_use = comm_action.comm_spec
|
||||||
|
# register hook to the parameters
|
||||||
|
if _filter_param_to_hook(node, operation_data, comm_action):
|
||||||
|
|
||||||
|
def wrapper(param, comm_spec, stream, overlap):
|
||||||
|
|
||||||
|
def hook_fn(grad):
|
||||||
|
if overlap:
|
||||||
|
with torch.cuda.stream(stream):
|
||||||
|
_all_reduce(grad, comm_spec, async_op=True)
|
||||||
|
else:
|
||||||
|
_all_reduce(grad, comm_spec, async_op=False)
|
||||||
|
|
||||||
|
param.register_hook(hook_fn)
|
||||||
|
|
||||||
|
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
|
||||||
|
|
||||||
|
def _shard_param(param, target_sharding_spec):
|
||||||
|
# apply the sharding spec of parameters
|
||||||
|
if target_sharding_spec.dim_partition_dict != {}:
|
||||||
|
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
||||||
|
setattr(param, 'sharding_spec', origin_sharding_spec)
|
||||||
|
# TODO: build a ColoParamter class to manager the distributed parameters
|
||||||
|
# we could use .data here, because all the operations just happen before the real training
|
||||||
|
# loop, so we don't need to track these operations in the autograd graph.
|
||||||
|
param = torch.nn.Parameter(
|
||||||
|
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
|
||||||
|
target_sharding_spec).detach().clone())
|
||||||
|
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.op == 'call_module':
|
if node.op == 'call_module':
|
||||||
target_module = node.graph.owning_module.get_submodule(node.target)
|
target_module = node.graph.owning_module.get_submodule(node.target)
|
||||||
@ -370,36 +438,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
|
|||||||
setattr(target_module, 'processed', True)
|
setattr(target_module, 'processed', True)
|
||||||
for name, param in target_module.named_parameters():
|
for name, param in target_module.named_parameters():
|
||||||
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
|
||||||
# apply the sharding spec of parameters
|
_shard_param(param, target_sharding_spec)
|
||||||
if target_sharding_spec.dim_partition_dict != {}:
|
|
||||||
origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
|
|
||||||
setattr(param, 'sharding_spec', origin_sharding_spec)
|
|
||||||
# TODO: build a ColoParamter class to manager the distributed parameters
|
|
||||||
# we could use .data here, because all the operations just happen before the real training
|
|
||||||
# loop, so we don't need to track these operations in the autograd graph.
|
|
||||||
param = torch.nn.Parameter(
|
|
||||||
shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
|
|
||||||
target_sharding_spec).detach().clone())
|
|
||||||
|
|
||||||
setattr(target_module, name, param)
|
setattr(target_module, name, param)
|
||||||
comm_actions = node.best_strategy.communication_actions
|
_add_hook_for_grad_communication(node, param)
|
||||||
for operation_data, comm_action in comm_actions.items():
|
|
||||||
comm_spec_to_use = comm_action.comm_spec
|
|
||||||
# register hook to the parameters
|
|
||||||
if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
|
|
||||||
|
|
||||||
def wrapper(param, comm_spec, stream, overlap):
|
|
||||||
|
|
||||||
def hook_fn(grad):
|
|
||||||
if overlap:
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
_all_reduce(grad, comm_spec, async_op=True)
|
|
||||||
else:
|
|
||||||
_all_reduce(grad, comm_spec, async_op=False)
|
|
||||||
|
|
||||||
param.register_hook(hook_fn)
|
|
||||||
|
|
||||||
wrapper(param, comm_spec_to_use, reduction_stream, overlap=overlap)
|
|
||||||
|
|
||||||
sharded_buffer_dict = {}
|
sharded_buffer_dict = {}
|
||||||
# apply the sharding spec of buffers
|
# apply the sharding spec of buffers
|
||||||
@ -427,37 +469,12 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
|
|||||||
target = getattr(target_module, atoms[-1])
|
target = getattr(target_module, atoms[-1])
|
||||||
|
|
||||||
target_sharding_spec = node.sharding_spec
|
target_sharding_spec = node.sharding_spec
|
||||||
if target_sharding_spec.dim_partition_dict != {}:
|
_shard_param(target, target_sharding_spec)
|
||||||
origin_sharding_spec = ShardingSpec(device_mesh, target.shape, {})
|
|
||||||
setattr(target, 'sharding_spec', origin_sharding_spec)
|
|
||||||
# TODO: build a ColoParamter class to manager the distributed parameters
|
|
||||||
# we could use .data here, because all the operations just happen before the real training
|
|
||||||
# loop, so we don't need to track these operations in the autograd graph.
|
|
||||||
target = torch.nn.Parameter(
|
|
||||||
shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
|
|
||||||
target_sharding_spec).detach().clone())
|
|
||||||
|
|
||||||
assert hasattr(target_module, atoms[-1])
|
assert hasattr(target_module, atoms[-1])
|
||||||
setattr(target_module, atoms[-1], target)
|
setattr(target_module, atoms[-1], target)
|
||||||
|
_add_hook_for_grad_communication(node, target)
|
||||||
|
|
||||||
comm_actions = node.best_strategy.communication_actions
|
|
||||||
for operation_data, comm_action in comm_actions.items():
|
|
||||||
comm_spec_to_use = comm_action.comm_spec
|
|
||||||
# register hook to the parameters
|
|
||||||
if isinstance(node._meta_data, torch.nn.parameter.Parameter) and comm_action.comm_type == CommType.HOOK:
|
|
||||||
|
|
||||||
def wrapper(param, comm_spec, stream, overlap):
|
|
||||||
|
|
||||||
def hook_fn(grad):
|
|
||||||
if overlap:
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
_all_reduce(grad, comm_spec, async_op=True)
|
|
||||||
else:
|
|
||||||
_all_reduce(grad, comm_spec, async_op=False)
|
|
||||||
|
|
||||||
param.register_hook(hook_fn)
|
|
||||||
|
|
||||||
wrapper(target, comm_spec_to_use, reduction_stream, overlap=overlap)
|
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
@ -471,14 +488,14 @@ def implicit_comm_action_apply(gm: torch.fx.GraphModule):
|
|||||||
def runtime_preparation_pass(gm: torch.fx.GraphModule,
|
def runtime_preparation_pass(gm: torch.fx.GraphModule,
|
||||||
solution: List[int],
|
solution: List[int],
|
||||||
device_mesh: DeviceMesh,
|
device_mesh: DeviceMesh,
|
||||||
strategies_constructor: StrategiesConstructor = None,
|
strategies_constructor: StrategiesConstructor,
|
||||||
overlap=False):
|
overlap=False):
|
||||||
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = _solution_annotatation(
|
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
|
||||||
gm, solution, strategies_constructor)
|
gm, solution, strategies_constructor)
|
||||||
gm = _size_value_converting(gm, device_mesh)
|
gm = size_value_converting_pass(gm, device_mesh)
|
||||||
gm = _node_args_converting(gm, device_mesh)
|
gm = node_args_converting_pass(gm, device_mesh)
|
||||||
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
|
# TODO: the pass below should be uncommented after the implementation of implicit_comm_action_apply_pass completed.
|
||||||
# gm = implicit_comm_action_apply(gm)
|
# gm = implicit_comm_action_apply(gm)
|
||||||
gm = _module_params_sharding(gm, device_mesh, overlap=overlap)
|
gm = module_params_sharding_pass(gm, device_mesh, overlap=overlap)
|
||||||
|
|
||||||
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
return gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
|
||||||
|
@ -0,0 +1,54 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import node_args_converting_pass
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = x.view(4, 4, 2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def insert_narrow(gm, x_node):
|
||||||
|
graph = gm.graph
|
||||||
|
with graph.inserting_after(x_node):
|
||||||
|
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
|
||||||
|
view_node = list(x_node.users.keys())[0]
|
||||||
|
new_args = list(view_node.args)
|
||||||
|
new_args[0] = shard_node
|
||||||
|
view_node.args = tuple(new_args)
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def test_node_args_converting_pass():
|
||||||
|
model = TestModule()
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
|
meta_args = {'x': torch.rand(4, 8).to('meta')}
|
||||||
|
input = torch.rand(4, 8)
|
||||||
|
tracer = ColoTracer()
|
||||||
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
|
||||||
|
x_node = list(graph.nodes)[0]
|
||||||
|
view_node = list(graph.nodes)[1]
|
||||||
|
sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
|
||||||
|
setattr(x_node, 'sharding_spec', sharding_spec)
|
||||||
|
setattr(view_node, 'sharding_spec', sharding_spec)
|
||||||
|
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
|
gm = node_args_converting_pass(gm, device_mesh)
|
||||||
|
gm = insert_narrow(gm, x_node)
|
||||||
|
gm.recompile()
|
||||||
|
output = gm(input)
|
||||||
|
assert output.shape == torch.Size([2, 4, 2])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_node_args_converting_pass()
|
@ -0,0 +1,65 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import size_value_converting_pass
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.tracer import ColoTracer
|
||||||
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
|
|
||||||
|
|
||||||
|
class TestModule(torch.nn.Module):
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
size = x.size()
|
||||||
|
return size
|
||||||
|
|
||||||
|
|
||||||
|
def insert_narrow(gm, x_node):
|
||||||
|
graph = gm.graph
|
||||||
|
with graph.inserting_after(x_node):
|
||||||
|
shard_node = graph.create_node('call_method', 'narrow', args=(x_node, 0, 0, 2), kwargs={})
|
||||||
|
size_node = list(x_node.users.keys())[0]
|
||||||
|
size_node.args = (shard_node,)
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def recover_narrow(gm, narrow_node):
|
||||||
|
graph = gm.graph
|
||||||
|
size_node = list(graph.nodes)[2]
|
||||||
|
x_node = narrow_node.args[0]
|
||||||
|
size_node.args = (x_node,)
|
||||||
|
graph.erase_node(narrow_node)
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def test_size_value_converting_pass():
|
||||||
|
model = TestModule()
|
||||||
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
|
mesh_shape = (2, 2)
|
||||||
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||||
|
meta_args = {'x': torch.rand(4, 8).to('meta')}
|
||||||
|
input = torch.rand(4, 8)
|
||||||
|
tracer = ColoTracer()
|
||||||
|
graph = tracer.trace(root=model, meta_args=meta_args)
|
||||||
|
|
||||||
|
x_node = list(graph.nodes)[0]
|
||||||
|
x_sharding_spec = ShardingSpec(device_mesh, entire_shape=(4, 8), dim_partition_dict={0: [0]})
|
||||||
|
setattr(x_node, 'sharding_spec', x_sharding_spec)
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
|
gm = insert_narrow(gm, x_node)
|
||||||
|
gm.recompile()
|
||||||
|
size = gm(input)
|
||||||
|
assert size == torch.Size([2, 8])
|
||||||
|
|
||||||
|
narrow_node = list(gm.graph.nodes)[1]
|
||||||
|
gm = recover_narrow(gm, narrow_node)
|
||||||
|
gm = size_value_converting_pass(gm, device_mesh)
|
||||||
|
gm = insert_narrow(gm, x_node)
|
||||||
|
gm.recompile()
|
||||||
|
size = gm(input)
|
||||||
|
assert size == torch.Size([4, 8])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_size_value_converting_pass()
|
@ -1,12 +1,9 @@
|
|||||||
from faulthandler import disable
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from xml.dom import WrongDocumentErr
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from typing_extensions import Self
|
|
||||||
|
|
||||||
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
|
from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler, LinearModuleHandler
|
||||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||||
|
Loading…
Reference in New Issue
Block a user