Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-20 20:54:55 +00:00)
[autoparallel] modify comm nodes' memory cost in construct chain (#2263)
* [autoparallel] align the data_ptr with the old version of the auto activation checkpoint pipeline
* [autoparallel] use fwd_time and bwd_time instead of fwd_flop and bwd_flop
* [autoparallel] specify comm nodes' memory cost in construct chain
This commit is contained in:
Parent: 1ea99b869e
Commit: 5c2ef9fc76
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Tuple
 from torch import Tensor
 from torch.fx import Graph, Node

+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply, runtime_comm_spec_apply
 from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
 from colossalai.fx.profiler import (
     activation_size,
@@ -131,8 +132,14 @@ class CheckpointSolverRotor(CheckpointSolverBase):
         fwd_mem_peak = 0
         for n in node:
             assert isinstance(n, Node), f'{n} is not a Node'
-            xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
-            fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))
+
+            if n.target == runtime_apply or n.target == runtime_comm_spec_apply:
+                # in this case we need to calculate memory usage directly based on the statics that hooked in node.meta
+                xbar += n.meta['fwd_mem_out']
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'])
+            else:
+                xbar += calculate_fwd_tmp(n) + calculate_fwd_out(n)
+                fwd_mem_peak = max(fwd_mem_peak, xbar + n.meta['fwd_mem_tmp'] + cls._extract_unused_output(n))

             # minimum flop count is required
             ftime += max(calculate_fwd_time(n), 1.0)
             btime += max(calculate_bwd_time(n), 1.0)
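For readers of the hunk above: communication nodes inserted by the runtime passes are not covered by the profiler helpers, so their footprint is read straight from node.meta. Below is a minimal, self-contained sketch of that accumulation logic, using plain dictionaries in place of torch.fx nodes and made-up byte counts; the dict keys mirror the meta fields read above, but the values and the peak_memory name are purely illustrative, and the unused-output correction is omitted for brevity.

# Sketch only (not library code): mock nodes carrying the meta fields the solver reads.
def peak_memory(nodes):
    xbar, fwd_mem_peak = 0, 0  # accumulated saved activation, running peak
    for n in nodes:
        if n["is_comm"]:
            # comm nodes: take the communication output size directly from meta
            xbar += n["fwd_mem_out"]
            fwd_mem_peak = max(fwd_mem_peak, xbar + n["fwd_mem_tmp"])
        else:
            # compute nodes: profiler-derived temporary + output activation sizes
            xbar += n["fwd_tmp"] + n["fwd_out"]
            fwd_mem_peak = max(fwd_mem_peak, xbar + n["fwd_mem_tmp"])
    return xbar, fwd_mem_peak

nodes = [
    {"is_comm": False, "fwd_tmp": 0, "fwd_out": 4 << 20, "fwd_mem_tmp": 1 << 20},
    {"is_comm": True, "fwd_mem_out": 4 << 20, "fwd_mem_tmp": 0},  # e.g. an all-gather result
    {"is_comm": False, "fwd_tmp": 0, "fwd_out": 2 << 20, "fwd_mem_tmp": 0},
]
print(peak_memory(nodes))  # (10485760, 10485760)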
@@ -151,6 +151,7 @@ class MetaInfoProp:
         # fetch other memory informations
         memory_cost = meta_info.memory_cost
         graph_info.fwd_mem_tmp = memory_cost.fwd.temp
+        graph_info.fwd_mem_out = memory_cost.fwd.activation
         graph_info.bwd_mem_tmp = memory_cost.bwd.temp
         graph_info.bwd_mem_out = memory_cost.bwd.activation
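The one-line addition above matters for the solver hunk earlier in this commit: CheckpointSolverRotor reads n.meta['fwd_mem_out'] for communication nodes, so the forward activation size has to be copied out of the meta-info memory cost. A rough sketch of that mapping, assuming hypothetical simplified dataclass containers and that the GraphInfo fields eventually land in node.meta:

from dataclasses import dataclass

# Illustrative stand-ins for the real memory-cost containers (names simplified).
@dataclass
class PhaseMemory:
    activation: int = 0  # bytes of output activation kept alive
    temp: int = 0        # bytes of temporary buffers

@dataclass
class MemoryCost:
    fwd: PhaseMemory
    bwd: PhaseMemory

cost = MemoryCost(fwd=PhaseMemory(activation=8 << 20, temp=0),
                  bwd=PhaseMemory(activation=0, temp=8 << 20))

# Mapping performed by the hunk above; before this commit fwd_mem_out was not
# filled in here, so the solver had no forward-output size to read for comm nodes.
node_meta = {
    "fwd_mem_tmp": cost.fwd.temp,
    "fwd_mem_out": cost.fwd.activation,  # newly propagated field
    "bwd_mem_tmp": cost.bwd.temp,
    "bwd_mem_out": cost.bwd.activation,
}
print(node_meta["fwd_mem_out"])  # 8388608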
@@ -100,7 +100,7 @@ def calculate_fwd_time(n: Node) -> float:
         fwd_time (float): the result of `fwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["fwd_flop"]
+    return n.meta["fwd_time"]


 def calculate_bwd_time(n: Node) -> float:
@@ -111,4 +111,4 @@ def calculate_bwd_time(n: Node) -> float:
         bwd_time (float): the result of `bwd_time`
     """
     # TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
-    return n.meta["bwd_flop"]
+    return n.meta["bwd_time"]
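With the two returns above switched from flop counts to profiled times, the ftime/btime accumulation in the solver (see the max(..., 1.0) lines in the earlier hunk) now sums measured durations with a floor of 1.0 per node. A small hedged sketch of that consumption path, using bare dictionaries in place of torch.fx nodes and invented timing values:

# Simplified stand-ins for the profiler helpers: they just read node.meta.
def calculate_fwd_time(meta):
    return meta["fwd_time"]  # previously meta["fwd_flop"]

def calculate_bwd_time(meta):
    return meta["bwd_time"]  # previously meta["bwd_flop"]

metas = [
    {"fwd_time": 1.5, "bwd_time": 3.0},  # a profiled compute node (units per the profiler)
    {"fwd_time": 0.0, "bwd_time": 0.0},  # a node with no recorded cost
]

# The solver enforces a minimum per-node cost, so zero-cost nodes still
# contribute 1.0 to the chain instead of making a segment look free.
ftime = sum(max(calculate_fwd_time(m), 1.0) for m in metas)
btime = sum(max(calculate_bwd_time(m), 1.0) for m in metas)
print(ftime, btime)  # 2.5 4.0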