[autoparallel] modify comm nodes' memory cost in construct chain (#2263)

* [autoparallel] align the data_ptr with the old version of auto activation checkpoint pipeline

* [autoparallel] using fwd_time and bwd_time instead of fwd_flop and bwd_flop

* [autoparallel] specifycomm nodes' memory cost in construct chain
This commit is contained in:
Boyuan Yao
2023-01-03 11:38:48 +08:00
committed by GitHub
parent 1ea99b869e
commit 5c2ef9fc76
3 changed files with 12 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
from torch import Tensor
from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.profiler import (
activation_size,
@@ -131,8 +132,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
fwd_mem_peak = 0
for n in node:
assert isinstance(n, Node), f'{n} is not a Node'
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
# in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
xbar += n.meta['fwd_mem_out']
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
else:
xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
# minimum flop count is required
ftime += max(calculate_fwd_time(n), 1.0)
btime += max(calculate_bwd_time(n), 1.0)

View File

@@ -151,6 +151,7 @@ class MetaInfoProp:
# fetch other memory informations
memory_cost = meta_info.memory_cost
graph_info.fwd_mem_tmp = memory_cost.fwd.temp
graph_info.fwd_mem_out = memory_cost.fwd.activation
graph_info.bwd_mem_tmp = memory_cost.bwd.temp
graph_info.bwd_mem_out = memory_cost.bwd.activation