[autoparallel] support origin activation ckpt on autoprallel system (#2468)

This commit is contained in:
YuliangLiu0306
2023-01-16 16:25:13 +08:00
committed by GitHub
parent 3a21485ead
commit 67e1912b59
4 changed files with 111 additions and 5 deletions

View File

@@ -128,6 +128,8 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
runtime_apply,
args=(node, origin_dict_node, input_dict_node,
node_to_index_dict[node], user_node_index))
if 'activation_checkpoint' in user_node.meta:
shape_consistency_node.meta['activation_checkpoint'] = user_node.meta['activation_checkpoint']
new_args = list(user_node.args)
new_kwargs = dict(user_node.kwargs)
@@ -208,6 +210,37 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
# substitute the origin node with comm_spec_apply_node
new_kwargs[str(node)] = comm_spec_apply_node
user.kwargs = new_kwargs
if 'activation_checkpoint' in node.meta:
comm_spec_apply_node.meta['activation_checkpoint'] = node.meta['activation_checkpoint']
return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""
mod_graph = gm.graph
nodes = tuple(mod_graph.nodes)
for node in nodes:
if not hasattr(node.meta, 'activation_checkpoint'):
from .runtime_preparation_pass import size_processing
user_act_annotation = -1
input_act_annotation = -1
for user_node in node.users.keys():
if 'activation_checkpoint' in user_node.meta:
user_act_annotation = user_node.meta['activation_checkpoint']
break
for input_node in node._input_nodes.keys():
if 'activation_checkpoint' in input_node.meta:
input_act_annotation = input_node.meta['activation_checkpoint']
break
if user_act_annotation == input_act_annotation and user_act_annotation != -1:
node.meta['activation_checkpoint'] = user_act_annotation
return gm