mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[autoparallel] Hook all meta information on ResNet nodes for auto activation checkpoint (#2248)
* [autoparallel] hook node meta on graph nodes for checkpoint solver
* [autoparallel] polish code
* [autoparallel] restore some node handlers
* colossalai/auto_parallel/passes/meta_info_prop.py
* [autoparallel] remove some unused import
* [autoparallel] hook bwd_mem_out
This commit is contained in:
@@ -1,15 +1,14 @@
|
||||
import uuid
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, NamedTuple, Tuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.fx.node import Node
|
||||
|
||||
from colossalai.auto_parallel.meta_profiler import MetaInfo
|
||||
from colossalai.fx._compatibility import compatibility, is_compatible_with_meta
|
||||
from colossalai.fx._compatibility import compatibility
|
||||
from colossalai.fx.profiler import GraphInfo
|
||||
from colossalai.fx.profiler.constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
|
||||
@@ -68,7 +67,7 @@ class MetaInfoProp:
|
||||
"""
|
||||
graph_info = GraphInfo()
|
||||
out = _normalize_tuple(getattr(node, '_meta_data', None))
|
||||
graph_info.fwd_out = list(out)
|
||||
graph_info.fwd_out = list(out) if out[0] is not None else []
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
@@ -97,7 +96,7 @@ class MetaInfoProp:
|
||||
"""
|
||||
Handle other kind of nodes
|
||||
"""
|
||||
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}"
|
||||
assert hasattr(node, 'best_metainfo'), f"Cannot find best_metainfo in node {node}, {node.op}"
|
||||
graph_info = GraphInfo()
|
||||
meta_info = node.best_metainfo
|
||||
meta_info: MetaInfo
|
||||
@@ -158,5 +157,13 @@ class MetaInfoProp:
|
||||
memory_cost = meta_info.memory_cost
|
||||
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
|
||||
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
|
||||
graph_info.bwd_mem_out = memory_cost.bwd.activation
|
||||
|
||||
# fetch flop information
|
||||
# here we use fwd_time and bwd_time to deal with the case that
|
||||
# communication cost is a float
|
||||
compute_cost = meta_info.compute_cost
|
||||
graph_info.fwd_time = compute_cost.fwd
|
||||
graph_info.bwd_time = compute_cost.bwd
|
||||
|
||||
node.meta = {**asdict(graph_info)}
|
||||
|
Reference in New Issue
Block a user