Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 22:52:25 +00:00
[fx/profiler] assigned UUID to each unrecorded tensor/ improved performance on GPT-2 (#1679)
* [fx/profiler] modify data_ptr into uuid for all tensors.
* [fx] modify uuid.
* [fx/profiler] tune performance on GPT-2.
* [fx] updates.
* [fx] debug.
* [fx] debug.
* [fx] cuda.
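The first bullet above is the heart of the change: instead of keying profiler bookkeeping on Tensor.data_ptr(), which is not usable for tensors that never own real storage (e.g. meta tensors) and can repeat when memory is reused, each tensor the profiler touches gets stamped with a unique id. A minimal sketch of that idea, assuming a hypothetical attach_uuid helper and a plain `uuid` attribute (not the exact ColossalAI API):

import uuid

import torch

def attach_uuid(t: torch.Tensor) -> torch.Tensor:
    # Stamp the tensor once with a stable identifier the profiler can
    # use even when data_ptr() is unavailable or gets recycled.
    if not hasattr(t, 'uuid'):
        t.uuid = uuid.uuid4()
    return t

def tensor_key(t: torch.Tensor):
    # Prefer the stamped uuid; fall back to data_ptr() for tensors
    # that were never seen by the profiler.
    return getattr(t, 'uuid', None) or t.data_ptr()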
@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace

@@ -39,16 +39,25 @@ class GraphInfo:
     bwd_flop (int): The backward FLOPs of a certain node.
     bwd_time (float): The real backward time (s) of a certain node.
     save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+    fwd_in (List): See the above illustration.
+    fwd_tmp (List): See the above illustration.
+    fwd_out (List): See the above illustration.
     fwd_mem_tmp (int): See the above illustration.
     fwd_mem_out (int): See the above illustration.
     bwd_mem_tmp (int): See the above illustration.
     bwd_mem_out (int): See the above illustration.
     """
+
+    # TODO(super-dainiu): removed redundant items, currently all of them are necessary for development
+
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
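Two details in the hunk above are worth noting: GraphInfo now keeps the saved tensors themselves (fwd_in, fwd_tmp, fwd_out) rather than only their byte counts, and the new List fields are declared with field(default_factory=list). A small, repo-independent illustration of why the factory is needed for mutable dataclass defaults:

from dataclasses import dataclass, field
from typing import List

@dataclass
class Info:
    # Each instance gets its own fresh list; a bare `= []` default is
    # rejected by dataclasses precisely because it would be shared.
    saved: List = field(default_factory=list)

a, b = Info(), Info()
a.saved.append('tensor-0')
assert b.saved == []    # b is unaffected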
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase


-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
@@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
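In the last hunk the analysis stops accumulating byte counts on the spot and instead appends each node's saved tensors to the new fwd_in / fwd_tmp lists, so memory figures can be derived from the lists afterwards. A rough sketch of that accumulate-then-measure pattern, assuming an activation_size-like helper that sums numel() * element_size() over a list of tensors (the real helper lives in the repo's .memory module and may differ):

from typing import List

import torch

def activation_size_of(tensors: List[torch.Tensor]) -> int:
    # Total byte footprint of a list of saved activations.
    return sum(t.numel() * t.element_size() for t in tensors)

fwd_in: List[torch.Tensor] = []
fwd_tmp: List[torch.Tensor] = []

# During graph traversal, tensors are collected per phase...
fwd_in += [torch.empty(4, 4)]
fwd_tmp += [torch.empty(2, 8), torch.empty(16)]

# ...and the memory numbers are computed from the lists afterwards.
fwd_mem_tmp = activation_size_of(fwd_tmp)    # 128 bytes for float32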
|