[fx] Fix wrong index in annotation and minimal flops in ckpt solver (#1521)

* [fx] fix wrong variable name in solver rotor

* [fx] fix wrong variable name in solver rotor

* [fx] fix the discretize bug

* [fx] fix the first op in activation checkpoint codegen

* [fx] fix some bugs of ckpt solver

* [fx] modify test_ckpt_torchvision

* [fx] set sequence to __sequence__ attr of GraphModule

* [fx] docstring modification

* [fx] remove performance test
This commit is contained in:
Boyuan Yao
2022-08-31 18:10:48 +08:00
committed by GitHub
parent 07f5a4e054
commit b231430bcb
2 changed files with 27 additions and 11 deletions

View File

@@ -1,6 +1,7 @@
from typing import List, Set, Tuple, Dict
import torch
from torch.fx import GraphModule, Node
from colossalai.fx.graph_module import ColoGraphModule
import math
from .linearize import linearize
from .utils import *
@@ -131,10 +132,10 @@ def _construct_chain(node_dict: Dict[int, Node], data: torch.Tensor, mem_unit: i
x_sizes.append(node_dict[key][-1].meta['tensor_meta'].numel *
torch.tensor([], dtype=node_dict[key][-1].meta['tensor_meta'].dtype).element_size())
for node in node_dict[key]:
fwd_time[-1] += node.__flops__
fwd_time[-1] += max(node.__flops__, 1)
# currently we haven't patched the backward flops count
bwd_time[-1] += node.__flops__ * 2
bwd_time[-1] += max(node.__flops__ * 2, 2)
xbar_sizes[-1] += node.__activation__
@@ -164,16 +165,16 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
elif isinstance(op, ForwardEnable):
in_ckpt = False
for idx in ckpt_region:
for node in node_dict[idx]:
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
ckpt_region = []
elif isinstance(op, ForwardCheck):
for idx in ckpt_region:
for node in node_dict[idx]:
for node_idx in ckpt_region:
for node in node_dict[node_idx]:
setattr(node, "activation_checkpoint", ckpt_idx)
ckpt_idx += 1
@@ -185,7 +186,19 @@ def _annotate_from_sequence(sequence: Sequence, node_dict: Dict[int, Node]) -> G
ckpt_region.append(idx)
def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> GraphModule:
def solver_rotor(gm: ColoGraphModule, data: torch.Tensor, mem_limit: int, mem_slots: int = 500) -> ColoGraphModule:
"""solver that automatically find activation checkpoint in rotor's manner
Args:
gm (ColoGraphModule): ColoGraphModule generated by tracing model.
data (torch.Tensor): input data.
mem_limit (int): memory budget in Byte.
mem_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
Returns:
ColoGraphModule: annotated ColoGraphModuled with __sequence__ attribute
"""
node_dict = linearize(gm)
mem_unit = mem_limit // mem_slots
MetaInfoProp(gm).run(data)
@@ -193,4 +206,7 @@ def solver_rotor(gm: GraphModule, data: torch.Tensor, mem_limit: int, mem_slots:
opt_table = _compute_table(chain, mem_slots)
sequence = _rec(chain, 0, chain.length, mem_slots - chain.cweight[0], opt_table)
_annotate_from_sequence(sequence, node_dict)
# set __sequence__ attribute to GraphModule
setattr(gm, "__sequence__", sequence)
return gm