[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)

* [autoparallel] hook node meta on graph nodes for checkpoint solver

* [autoparallel] polish code

* [autoparallel] restore some node handlers

* colossalai/auto_parallel/passes/meta_info_prop.py

* [autoparallel] remove some unused import

* [autoparallel] hook bwd_mem_out
This commit is contained in:
Boyuan Yao
2023-01-02 16:25:18 +08:00
committed by GitHub
parent c8c79102f0
commit ab38aebace
6 changed files with 132 additions and 76 deletions

View File

@@ -1,15 +1,14 @@
import uuid
from dataclasses import asdict
from typing import Any, Dict, List, NamedTuple, Tuple
from typing import List
import torch
import torch.fx
from torch.fx import GraphModule
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler import MetaInfo
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
@@ -68,7 +67,7 @@ class MetaInfoProp:
"""
graph_info = GraphInfo()
out = _normalize_tuple(getattr(node, '_meta_data', None))
graph_info.fwd_out = list(out)
graph_info.fwd_out = list(out) if out[0] is not None else []
node.meta = {**asdict(graph_info)}
@compatibility(is_backward_compatible=False)
@@ -97,7 +96,7 @@ class MetaInfoProp:
"""
Handle other kinds of nodes
"""
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
graph_info = GraphInfo()
meta_info = node.best_metainfo
meta_info: MetaInfo
@@ -158,5 +157,13 @@ class MetaInfoProp:
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
graph_info.bwd_mem_out = memory_cost.bwd.activation
# fetch flop information
# here we use fwd_time and bwd_time to handle the case where
# the communication cost is a float
compute_cost = meta_info.compute_cost
graph_info.fwd_time = compute_cost.fwd
graph_info.bwd_time = compute_cost.bwd
node.meta = {**asdict(graph_info)}