[autoparallel] add split handler (#2032)

* [autoparallel] add split handler

* add numerical test and runtime passes
This commit is contained in:
YuliangLiu0306
2022-11-29 11:03:51 +08:00
committed by GitHub
parent 28aa9a4294
commit 0dbcd4a6f5
9 changed files with 500 additions and 22 deletions

View File

@@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.comm_spec import CommSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
shape_consistency_manager = ShapeConsistencyManager()
@@ -27,6 +28,23 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
                                      user_node_index: int):
    """
    Runtime shape-consistency conversion for activations that are a tuple or a list.

    Every element of ``node`` is converted from its origin sharding spec to the
    sharding spec expected by the user node, and the converted elements are
    packed back into a container of the same type as ``node``.

    Args:
        node: the iterable activation; it is indexed by position and its type
            is reused to build the returned container.
        origin_dict: indexed with ``node_index`` to get the per-element origin
            sharding specs.
        input_dict: indexed with ``node_index`` and then ``user_node_index`` to
            get the per-element target sharding specs.
        node_index: key of the producing node in ``origin_dict`` and ``input_dict``.
        user_node_index: key of the consuming (user) node inside
            ``input_dict[node_index]``.

    Returns:
        A ``type(node)`` instance holding the converted elements.
    """
    origin_specs = origin_dict[node_index]
    target_specs = input_dict[node_index][user_node_index]

    # The spec sequences are walked in lockstep; zip stops at the shorter one,
    # so only positions that have both an origin and a target spec are converted.
    converted = [
        shape_consistency_manager.apply_for_autoparallel_runtime(node[position], source_spec, wanted_spec)
        for position, (source_spec, wanted_spec) in enumerate(zip(origin_specs, target_specs))
    ]

    # Rebuild the same container type the activation came in as.
    return type(node)(converted)
def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
"""
This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
@@ -81,13 +99,34 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
continue
for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if isinstance(node.sharding_spec, (list, tuple)):
assert isinstance(
node.target_sharding_specs,
(list,
tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
total_difference = 0
for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
node.target_sharding_specs[user_node_index]):
total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
if total_difference == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply_for_iterable_object,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
else:
assert isinstance(node.sharding_spec,
ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
continue
with mod_graph.inserting_before(user_node):
shape_consistency_node = mod_graph.create_node('call_function',
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
# the origin node may be a positional argument or key word argument of user node