mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[autoparallel] add split handler (#2032)
* [autoparallel] add split handler
* add numerical test and runtime passes
@@ -13,6 +13,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.comm_spec import CommSpec
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 shape_consistency_manager = ShapeConsistencyManager()
 
@@ -27,6 +28,23 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
 
 
+def runtime_apply_for_iterable_object(node: Node, origin_dict: Dict, input_dict: Dict, node_index: int,
+                                      user_node_index: int):
+    """
+    This method will be invoked during runtime to do the shape consistency, which makes sure the activations of type
+    tuple or list are converted into the form expected by the user node.
+    """
+    rst = []
+    for index, (origin_sharding_spec,
+                target_sharding_spec) in enumerate(zip(origin_dict[node_index],
+                                                       input_dict[node_index][user_node_index])):
+        rst.append(
+            shape_consistency_manager.apply_for_autoparallel_runtime(node[index], origin_sharding_spec,
+                                                                     target_sharding_spec))
+    rst = type(node)(rst)
+    return rst
+
+
 def runtime_comm_spec_apply(tensor: torch.Tensor, comm_actions_dict: Dict, node_index: int, op_data_name: str):
     """
     This method will be invoked during runtime to apply the comm action following the instruction of comm spec.
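The new runtime_apply_for_iterable_object walks a tuple or list of activations together with their origin and target sharding specs, converts each element, and rebuilds a container of the same type. Below is a minimal, self-contained sketch of that pattern only; convert_to_spec, apply_for_iterable, and the spec strings are placeholders for illustration, not ColossalAI's ShapeConsistencyManager.apply_for_autoparallel_runtime.

# Sketch of the element-wise conversion pattern used by runtime_apply_for_iterable_object.
# The names below are hypothetical stand-ins, not library code.
from typing import List, Sequence

import torch


def convert_to_spec(tensor: torch.Tensor, origin_spec: str, target_spec: str) -> torch.Tensor:
    # Placeholder: a real implementation would perform the communication needed to move
    # the tensor from origin_spec to target_spec; here the tensor is returned unchanged.
    return tensor


def apply_for_iterable(activations: Sequence[torch.Tensor], origin_specs: List[str],
                       target_specs: List[str]) -> Sequence[torch.Tensor]:
    rst = []
    # Pair each element's origin spec with the spec its user expects, element by element.
    for index, (origin_spec, target_spec) in enumerate(zip(origin_specs, target_specs)):
        rst.append(convert_to_spec(activations[index], origin_spec, target_spec))
    # Preserve the container type (tuple stays tuple, list stays list),
    # mirroring `rst = type(node)(rst)` in the pass above.
    return type(activations)(rst)


outs = apply_for_iterable((torch.ones(2), torch.zeros(2)), ['R', 'R'], ['S0', 'R'])
print(type(outs), [t.shape for t in outs])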
@@ -81,13 +99,34 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
             continue
 
         for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
-            if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
-                continue
-            with mod_graph.inserting_before(user_node):
-                shape_consistency_node = mod_graph.create_node('call_function',
-                                                               runtime_apply,
-                                                               args=(node, origin_dict_node, input_dict_node,
-                                                                     node_to_index_dict[node], user_node_index))
+            if isinstance(node.sharding_spec, (list, tuple)):
+                assert isinstance(
+                    node.target_sharding_specs,
+                    (list,
+                     tuple)), 'target sharding specs should be tuple or list when node.sharding_spec is tuple or list'
+                total_difference = 0
+                for sharding_spec, target_sharding_spec in zip(node.sharding_spec,
+                                                               node.target_sharding_specs[user_node_index]):
+                    total_difference += sharding_spec.sharding_sequence_difference(target_sharding_spec)
+                if total_difference == 0:
+                    continue
+                with mod_graph.inserting_before(user_node):
+                    shape_consistency_node = mod_graph.create_node('call_function',
+                                                                   runtime_apply_for_iterable_object,
+                                                                   args=(node, origin_dict_node, input_dict_node,
+                                                                         node_to_index_dict[node], user_node_index))
+
+            else:
+                assert isinstance(node.sharding_spec,
+                                  ShardingSpec), 'node.sharding_spec should be type of ShardingSpec, tuple or list.'
+                if node.sharding_spec.sharding_sequence_difference(node.target_sharding_specs[user_node_index]) == 0:
+                    continue
+                with mod_graph.inserting_before(user_node):
+                    shape_consistency_node = mod_graph.create_node('call_function',
+                                                                   runtime_apply,
+                                                                   args=(node, origin_dict_node, input_dict_node,
+                                                                         node_to_index_dict[node], user_node_index))
+
             new_args = list(user_node.args)
             new_kwargs = dict(user_node.kwargs)
             # the origin node may be a positional argument or key word argument of user node
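The modified _shape_consistency_apply pass uses the standard torch.fx rewrite idiom: open graph.inserting_before(user_node), create a 'call_function' node, then rewire the user's arguments so it consumes the inserted node's output. Below is a rough, self-contained sketch of that idiom on a toy module; Tiny and scale_by_two are made-up stand-ins for illustration, not the pass's actual runtime_apply target or ColossalAI code.

# Sketch: insert a call_function node before a user node in a torch.fx graph and rewire it.
import torch
import torch.fx


def scale_by_two(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a runtime transformation applied before the user consumes x.
    return x * 2


class Tiny(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x)
        return y + 1


gm = torch.fx.symbolic_trace(Tiny())
graph = gm.graph

for node in list(graph.nodes):
    if node.op == 'call_function' and node.target is torch.relu:
        # Snapshot users first: inserting a new consumer of `node` would otherwise change the set.
        for user_node in list(node.users):
            # Same idiom as `with mod_graph.inserting_before(user_node): create_node(...)` above.
            with graph.inserting_before(user_node):
                new_node = graph.create_node('call_function', scale_by_two, args=(node,))
            # Rewire the user so it reads from the inserted node instead of the original output.
            user_node.replace_input_with(node, new_node)

graph.lint()
gm.recompile()
print(gm(torch.randn(2)))  # computes relu(x) * 2 + 1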