[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
This commit is contained in:
Super Daniel
2022-10-26 14:24:41 +08:00
committed by GitHub
parent 63f250bbd4
commit 0584654c79
14 changed files with 177 additions and 122 deletions

View File

@@ -1,7 +1,9 @@
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']

View File

@@ -1,15 +1,17 @@
import math
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global vairable to indicate whether the solver is failed
SOLVER_FAILED = False
@@ -18,7 +20,7 @@ SOLVER_FAILED = False
# https://gitlab.inria.fr/hiepacs/rotor
# paper link: https://hal.inria.fr/hal-02352969
def _compute_table(chain: Chain, mmax) -> Tuple:
"""Returns the optimal table: a tuple containing:
"""Returns the optimal table: a tuple containing:
Opt[m][lmin][lmax] with lmin = 0...chain.length
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
"""Get the forward xbar of a node
Args:
node (List[Node]): List of torch.fx Node,
node (List[Node]): List of torch.fx Node,
indicates a node in linearized graph
Returns:
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
# build module if module not found
except ModuleNotFoundError:
import subprocess
import os
import subprocess
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(