[fx/profiler] assigned UUID to each unrecorded tensor / improved performance on GPT-2 (#1679)

* [fx/profiler] replace data_ptr with uuid for all tensors.

* [fx] modify uuid.

* [fx/profiler] tune performance on GPT-2.

* [fx] updates.

* [fx] debug.

* [fx] debug.

* [fx] cuda.
Author: Super Daniel
Date: 2022-10-11 11:03:35 +08:00 (committed via GitHub)
Parent: 0df5034a36
Commit: 3dd6994427

13 changed files with 262 additions and 94 deletions
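The core change behind this commit: instead of keying saved activations by `data_ptr()`, which reports 0 for meta tensors and can be reused after a storage is freed, the profiler tags every tensor it has not seen before with a UUID, giving each one a stable identity for the lifetime of the analysis. A minimal sketch of the idea (the helper `ensure_uuid` is illustrative, not the actual ColossalAI API):

```python
import uuid

import torch


def ensure_uuid(t: torch.Tensor) -> torch.Tensor:
    # Attach a uuid the first time the profiler sees the tensor. data_ptr()
    # is not a reliable key: meta/unmaterialized tensors report 0, and a
    # freed address can be reused by a later allocation.
    if not hasattr(t, 'uuid'):
        t.uuid = uuid.uuid4()
    return t


x = ensure_uuid(torch.empty(2, 2, device='meta'))
y = ensure_uuid(x)
assert x.uuid == y.uuid  # identity stays stable across repeated visits
```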


@@ -1,7 +1,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from functools import partial
-from typing import Dict
+from typing import Dict, List
 from torch.fx import Graph, Node
 from .memory import activation_size, is_inplace
@@ -39,16 +39,25 @@ class GraphInfo:
         bwd_flop (int): The backward FLOPs of a certain node.
         bwd_time (float): The real backward time (s) of a certain node.
         save_fwd_in (bool): The decision variable of whether to save the fwd_mem_out of parent nodes.
+        fwd_in (List): See the above illustration.
+        fwd_tmp (List): See the above illustration.
+        fwd_out (List): See the above illustration.
         fwd_mem_tmp (int): See the above illustration.
         fwd_mem_out (int): See the above illustration.
         bwd_mem_tmp (int): See the above illustration.
         bwd_mem_out (int): See the above illustration.
     """
+    # TODO(super-dainiu): remove redundant items, currently all of them are necessary for development
     fwd_flop: int = 0
     fwd_time: float = 0.0
     bwd_flop: int = 0
     bwd_time: float = 0.0
     save_fwd_in: bool = False
+    fwd_in: List = field(default_factory=list)
+    fwd_tmp: List = field(default_factory=list)
+    fwd_out: List = field(default_factory=list)
     fwd_mem_tmp: int = 0
     fwd_mem_out: int = 0
     bwd_mem_tmp: int = 0
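With `fwd_in`, `fwd_tmp`, and `fwd_out` now holding the saved tensors themselves rather than pre-summed byte counts, the old scalar view can still be recovered on demand, and storages shared between the lists can later be counted once. A hedged sketch of how the two views relate, assuming `activation_size` from `.memory` accepts a list of tensors:

```python
# Recover the previous scalar counters from the new list-valued fields.
info = GraphInfo()
fwd_mem_in = activation_size(info.fwd_in)    # replaces the save_fwd_in byte check
fwd_mem_tmp = activation_size(info.fwd_tmp)  # replaces the accumulated int
```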
@@ -60,10 +69,6 @@ def is_phase(n: Node, phase: Phase) -> bool:
     return n.meta['phase'] == phase
-
-def is_saved(n: Node):
-    return len(n.meta['saved_tensor'])
-
 def autograd_graph_analysis(graph: Graph) -> GraphInfo:
     """Analyze the autograd node dependencies and find out the memory usage.
     Basically the input graph should have all nodes marked for keyword `phase`.
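For context, `autograd_graph_analysis` expects every node's `meta['phase']` to be set before it runs. A minimal, hypothetical sketch of such marking (the blanket use of `Phase.FORWARD` is illustrative; the real tracer distinguishes placeholder, forward, and backward nodes):

```python
from torch.fx import symbolic_trace

gm = symbolic_trace(model)  # `model` is any nn.Module
for node in gm.graph.nodes:
    # Placeholder and backward nodes would be tagged analogously.
    node.meta['phase'] = Phase.FORWARD
graph_info = autograd_graph_analysis(gm.graph)
```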
@@ -113,9 +118,9 @@ def autograd_graph_analysis(graph: Graph) -> GraphInfo:
         # Otherwise, the tensor belongs to `fwd_mem_tmp`. If we checkpoint
         # the node, `fwd_mem_tmp` can be freed.
         if is_phase(n, Phase.PLACEHOLDER):
-            graph_info.save_fwd_in |= activation_size(n.meta['saved_tensor']) > 0
+            graph_info.fwd_in += n.meta['saved_tensor']
         if is_phase(n, Phase.FORWARD):
-            graph_info.fwd_mem_tmp += activation_size(n.meta['saved_tensor'])
+            graph_info.fwd_tmp += n.meta['saved_tensor']
         elif is_phase(n, Phase.BACKWARD):
             if len(n.users):
                 graph_info.bwd_mem_tmp = max(graph_info.bwd_mem_tmp, _peak_memory(deps))
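Because one underlying storage may appear in several nodes' `saved_tensor` lists, summing `activation_size` per node double counts; collecting the tensors and deduplicating by their UUID before summing avoids that. A minimal sketch of such a dedup (not the exact library code), assuming each saved tensor carries the profiler-assigned `uuid` attribute:

```python
from typing import Iterable

import torch


def deduped_size(tensors: Iterable[torch.Tensor]) -> int:
    # Count each unique storage once, keyed by the profiler-assigned uuid.
    seen, total = set(), 0
    for t in tensors:
        key = getattr(t, 'uuid', id(t))  # fall back to object identity if untagged
        if key not in seen:
            seen.add(key)
            total += t.numel() * t.element_size()
    return total
```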