[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
Authored by Super Daniel on 2022-10-26 14:24:41 +08:00; committed by GitHub.
parent 63f250bbd4
commit 0584654c79
14 changed files with 177 additions and 122 deletions

View File

@@ -1,7 +1,9 @@
import math
from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
__all__ = ['chen_greedy']

View File

@@ -1,15 +1,17 @@
import math
import sys
from typing import List, Tuple
from colossalai.fx.profiler.memory import calculate_fwd_in
from torch.fx import Node
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
import math
from .linearize import linearize
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
from colossalai.fx.graph_module import ColoGraphModule
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
from colossalai.logging import get_dist_logger
from .linearize import linearize
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
# global variable to indicate whether the solver has failed
SOLVER_FAILED = False
@@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
# build module if module not found
except ModuleNotFoundError:
import subprocess
import os
import subprocess
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
this_dir = os.path.dirname(os.path.abspath(__file__))
result = subprocess.Popen(

View File

@@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_flatten
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
@compatibility(is_backward_compatible=True)
class ConcreteInfoProp(torch.fx.Interpreter):

View File

@@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
import torch
import torch.fx
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
profile_function, profile_method, profile_module)
from torch.fx.node import Argument, Node, Target
from torch.utils._pytree import tree_map
from colossalai.fx._compatibility import compatibility
from colossalai.fx.profiler import (
GraphInfo,
activation_size,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_tmp,
profile_function,
profile_method,
profile_module,
)
@compatibility(is_backward_compatible=True)
class TensorMetadata(NamedTuple):

View File

@@ -1,12 +1,18 @@
from .._compatibility import is_compatible_with_meta
if is_compatible_with_meta():
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .opcount import flop_mapping
from .profiler import profile_function, profile_method, profile_module
from .shard_utils import (
calculate_bwd_time,
calculate_fwd_in,
calculate_fwd_out,
calculate_fwd_time,
calculate_fwd_tmp,
)
from .tensor import MetaTensor
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
from .dataflow import GraphInfo
from .memory import activation_size, is_inplace, parameter_size
from .memory_utils import activation_size, is_inplace, parameter_size
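For reference, a minimal usage sketch of the byte-accounting helpers re-exported above (illustrative only, not part of this diff):

# Hypothetical call site; assumes the public API re-exported by this __init__.py.
import torch
from colossalai.fx.profiler import activation_size, parameter_size

model = torch.nn.Linear(64, 64)
print(parameter_size(model))                 # (64 * 64 + 64) fp32 params * 4 bytes each
print(activation_size(torch.zeros(8, 64)))   # 8 * 64 fp32 elements * 4 bytes each
# calculate_fwd_in / calculate_fwd_out / calculate_fwd_tmp instead take a traced
# torch.fx.Node whose meta fields have been populated by the profiler passes.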

View File

@@ -6,7 +6,7 @@ from typing import Dict, List
from torch.fx import Graph, Node
from .._compatibility import compatibility
from .memory import activation_size, is_inplace
from .memory_utils import activation_size, is_inplace
class Phase(Enum):

View File

@@ -1,5 +1,5 @@
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
from .profiler import profile_function, profile_method, profile_module
from .profiler_function import *
from .profiler_module import *
from .registry import meta_profiler_function, meta_profiler_module
from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp

View File

@@ -5,7 +5,7 @@ import torch
from torch.fx.node import Argument, Target
from ..._compatibility import compatibility
from ..memory import activation_size
from ..memory_utils import activation_size
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
from .registry import meta_profiler_function, meta_profiler_module

View File

@@ -0,0 +1,71 @@
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from .._compatibility import compatibility, is_compatible_with_meta
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
act_size = 0
if isinstance(out, torch.Tensor):
if out.is_quantized:
act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
else:
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The parameter size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
Args:
node (Node): torch.fx.Node
Returns:
bool: indicates whether this op is inplace
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
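A short sketch of how these helpers behave (illustrative only; it assumes the new module is importable as colossalai.fx.profiler.memory_utils, as the updated imports elsewhere in this commit suggest):

import torch
from torch.fx import symbolic_trace
from colossalai.fx.profiler.memory_utils import activation_size, is_inplace

# activation_size recurses into dicts, lists, tuples and sets of tensors.
outs = {"logits": torch.zeros(4, 10), "hidden": [torch.zeros(4, 16), torch.zeros(4, 16)]}
assert activation_size(outs) == (4 * 10 + 4 * 16 + 4 * 16) * 4   # fp32 elements * 4 bytes

class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(x)

gm = symbolic_trace(Net())
# is_inplace reads the `inplace` flag from call_function kwargs or from the target submodule.
print([is_inplace(n) for n in gm.graph.nodes])   # [False, True, False] for placeholder / ReLU / output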

View File

@@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
from .memory import activation_size, parameter_size
from .memory_utils import activation_size, parameter_size
from .opcount import flop_mapping
from .tensor import MetaTensor

View File

@@ -1,58 +1,18 @@
from typing import Dict, List, Tuple, Union
import torch
from torch.fx import GraphModule, Node
from torch.fx import Node
from .._compatibility import compatibility, is_compatible_with_meta
from .memory_utils import activation_size
if is_compatible_with_meta():
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
__all__ = [
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
]
@compatibility(is_backward_compatible=True)
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
act_size = 0
if isinstance(out, torch.Tensor):
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
for element in out:
act_size += activation_size(element)
return act_size
@compatibility(is_backward_compatible=True)
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate parameter size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The parameter size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
@compatibility(is_backward_compatible=False)
def calculate_fwd_in(n: Node) -> int:
"""A helper function to calculate `fwd_in`
"""A helper function to calculate `fwd_in` (with sharding spec)
Args:
n (Node): a node from the graph
@@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
Returns:
fwd_in (int): the result of `fwd_in`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
return activation_size(n.meta["fwd_in"])
@compatibility(is_backward_compatible=False)
def calculate_fwd_tmp(n: Node) -> int:
"""A helper function to calculate `fwd_tmp`
"""A helper function to calculate `fwd_tmp` (with sharding spec)
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
Args:
@@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
fwd_tmp (int): the result of `fwd_tmp`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def is_relu_like_node(n: Node) -> bool:
"""Check if a node is a ReLU-like node.
ReLU-like nodes have the following properties:
@@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
return 0
@compatibility(is_backward_compatible=False)
def calculate_fwd_out(n: Node) -> int:
"""A helper function to calculate `fwd_out`
"""A helper function to calculate `fwd_out` (with sharding spec)
Args:
n (Node): a node from the graph
@@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
fwd_out (int): the result of `fwd_out`
"""
# TODO(super-dainiu): should divide the memory by sharding spec
def intersect(a, b):
return {k: a[k] for k in a if k in b}
@@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
return activation_size(intersect(fwd_in, fwd_out))
def is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
def calculate_fwd_time(n: Node) -> float:
"""A helper function to calculate `fwd_time` (with sharding spec)
Args:
node (Node): torch.fx.Node
n (Node): a node from the graph
Returns:
bool: indicates whether this op is inplace
fwd_time (float): the result of `fwd_time`
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
if is_compatible_with_meta():
from .constants import ALIAS_ATEN
if n.target in ALIAS_ATEN:
inplace = True
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["fwd_flop"]
return inplace
def calculate_bwd_time(n: Node) -> float:
"""A helper function to calculate `bwd_time` (with sharding spec)
Args:
n (Node): a node from the graph
Returns:
bwd_time (float): the result of `bwd_time`
"""
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
return n.meta["bwd_flop"]

View File

@@ -1,7 +1,5 @@
from colossalai.fx.profiler.memory import activation_size
import torch
from torch.fx import Node, Graph
from torch.fx.graph import _Namespace
from torch.fx import Graph, Node
from torch.utils._pytree import tree_map

View File

@@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union
import torch
import torch.fx
import torchvision.models as tm
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from gpt_utils import gpt2_medium, gpt2_xl
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor
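Putting the pieces together, an end-to-end sketch in the spirit of this test (illustrative only; the model choice, input shape and the exact run call are assumptions, not taken from the diff):

import torch
import torchvision.models as tm
from torch.fx import GraphModule

from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
from colossalai.fx.tracer.tracer import ColoTracer

model = tm.resnet18()
data = torch.rand(2, 3, 224, 224)
if is_compatible_with_meta():
    from colossalai.fx.profiler import MetaTensor
    data = MetaTensor(data)   # wrap the input so profiling runs on meta tensors

# ColoTracer subclasses torch.fx.Tracer; wrap the traced graph in a GraphModule.
graph = ColoTracer().trace(model)
gm = GraphModule(model, graph, model.__class__.__name__)

# MetaInfoProp is an fx Interpreter; run() propagates the input and fills node.meta.
MetaInfoProp(gm).run(data)

print("parameter bytes:", parameter_size(model))
print("activation bytes:", sum(calculate_fwd_out(n) + calculate_fwd_tmp(n) for n in gm.graph.nodes))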