mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from torch.fx.graph import CodeGen
|
||||
except:
|
||||
pass
|
||||
from torch.fx.graph import (
|
||||
CodeGen,
|
||||
PythonCode,
|
||||
_custom_builtins,
|
||||
_format_target,
|
||||
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
|
||||
"""
|
||||
Check if the node could end the ckpt region at `ckpt_level`
|
||||
"""
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
return node.meta['info'].to_recompute[ckpt_level] is not None
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
return node.meta['info'].activation_checkpoint[ckpt_level] is not None
|
||||
return True
|
||||
|
||||
|
||||
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
|
||||
current_region = None
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
if len(node.meta['info'].to_recompute) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
|
||||
if len(node.meta['info'].activation_checkpoint) > ckpt_level:
|
||||
act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
|
||||
|
||||
# this activation checkpoint label is not set yet
|
||||
# meaning this is the first node of the activation ckpt region
|
||||
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,
|
||||
|
||||
# label given by each layer, e.g. if you are currently at level (0, 1, 1)
|
||||
# the label will be '0_1_1'
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
|
||||
label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
|
||||
ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
|
||||
ckpt_func.append(f'{ckpt_fn_def}\n')
|
||||
|
||||
# if there is more level to fetch
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
|
||||
if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
|
||||
ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
|
||||
ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
|
||||
start_idx = [item[0] for item in ckpt_regions]
|
||||
end_idx = [item[1] for item in ckpt_regions]
|
||||
|
||||
node_list = list(nodes)
|
||||
|
||||
node_idx = 0
|
||||
|
Reference in New Issue
Block a user