mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py. * [fx] add shard utils. * [fx] fix import. * [fx] check code style. * [fx] add comment. * [autoparallel] first move. * [fx] add time computations.
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
import math
|
||||
from typing import List, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
import math
|
||||
|
||||
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||
|
||||
__all__ = ['chen_greedy']
|
||||
|
@@ -1,15 +1,17 @@
|
||||
import math
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
from colossalai.fx.profiler.memory import calculate_fwd_in
|
||||
|
||||
from torch.fx import Node
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
|
||||
import math
|
||||
from .linearize import linearize
|
||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .linearize import linearize
|
||||
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
|
||||
|
||||
# global variable to indicate whether the solver has failed
|
||||
SOLVER_FAILED = False
|
||||
|
||||
@@ -18,7 +20,7 @@ SOLVER_FAILED = False
|
||||
# https://gitlab.inria.fr/hiepacs/rotor
|
||||
# paper link: https://hal.inria.fr/hal-02352969
|
||||
def _compute_table(chain: Chain, mmax) -> Tuple:
|
||||
"""Returns the optimal table: a tuple containing:
|
||||
"""Returns the optimal table: a tuple containing:
|
||||
Opt[m][lmin][lmax] with lmin = 0...chain.length
|
||||
and lmax = lmin...chain.length (lmax is not included) and m = 0...mmax
|
||||
what[m][lmin][lmax] is (True,) if the optimal choice is a chain checkpoint
|
||||
@@ -127,7 +129,7 @@ def _fwd_xbar(node: List[Node]) -> int:
|
||||
"""Get the forward xbar of a node
|
||||
|
||||
Args:
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
node (List[Node]): List of torch.fx Node,
|
||||
indicates a node in linearized graph
|
||||
|
||||
Returns:
|
||||
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
|
||||
|
||||
# build module if module not found
|
||||
except ModuleNotFoundError:
|
||||
import subprocess
|
||||
import os
|
||||
import subprocess
|
||||
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
result = subprocess.Popen(
|
||||
|
Reference in New Issue
Block a user