[autoparallel] Add pofo sequence annotation (#1637)

* [autoparallel] annotate pofo sequence

* [autoparallel] remove unused print

* [autoparallel] fix some code
This commit is contained in:
Boyuan Yao
2022-09-24 01:52:57 +08:00
committed by GitHub
parent 04bbabeea8
commit f921733621
3 changed files with 133 additions and 6 deletions

View File

@@ -145,7 +145,7 @@ def _find_ckpt_regions(nodes: List[Node]):
def _find_offload_regions(nodes: List[Node]):
"""This function is to find the offload regions
In pofo algorithm, during annotation, we will annotate the offload region with the
tuple in the form of (idx, offload_input, offload_bar). idx indicates the offload
list in the form of [idx, offload_input, offload_bar]. idx indicates the offload
region's index, offload_input is a bool type indicates whether we need to offload
the input, offload_bar is a bool type indicates whether we need to offload all the
intermediate x_bars of this region.
@@ -157,7 +157,7 @@ def _find_offload_regions(nodes: List[Node]):
current_region = None
for idx, node in enumerate(nodes):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), list):
act_offload_label = node.activation_offload
if current_region == None: