mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py. * [fx] add shard utils. * [fx] fix import. * [fx] check code style. * [fx] add comment. * [autoparallel] first move. * [fx] add time computations.
This commit is contained in:
114
colossalai/fx/profiler/shard_utils.py
Normal file
114
colossalai/fx/profiler/shard_utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
from torch.fx import Node
|
||||
|
||||
from .._compatibility import compatibility, is_compatible_with_meta
|
||||
from .memory_utils import activation_size
|
||||
|
||||
if is_compatible_with_meta():
|
||||
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||
|
||||
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def calculate_fwd_in(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_in` (with sharding spec)
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
fwd_in (int): the result of `fwd_in`
|
||||
"""
|
||||
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||
return activation_size(n.meta["fwd_in"])
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def calculate_fwd_tmp(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_tmp` (with sharding spec)
|
||||
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
fwd_tmp (int): the result of `fwd_tmp`
|
||||
"""
|
||||
|
||||
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||
def is_relu_like_node(n: Node) -> bool:
|
||||
"""Check if a node is a ReLU-like node.
|
||||
ReLU-like nodes have the following properties:
|
||||
- They are either `call_function` or `call_module`
|
||||
- Their output tensors are directly saved for backward
|
||||
- Their input tensors are not saved for backward
|
||||
|
||||
An example is `torch.nn.functional.softmax` which has (forward + backward):
|
||||
def forward(self, input_2):
|
||||
_softmax_default = torch.ops.aten._softmax.default(input_2, None, None); input_2 = None
|
||||
zeros_like_default = torch.ops.aten.zeros_like.default(_softmax_default, dtype = None, layout = None, device = None, pin_memory = None)
|
||||
detach_default = torch.ops.aten.detach.default(_softmax_default); _softmax_default = None
|
||||
_softmax_backward_data_default = torch.ops.aten._softmax_backward_data.default(zeros_like_default, detach_default, None, None); zeros_like_default = detach_default = None
|
||||
detach_default_1 = torch.ops.aten.detach.default(_softmax_backward_data_default); _softmax_backward_data_default = None
|
||||
detach_default_2 = torch.ops.aten.detach.default(detach_default_1); detach_default_1 = None
|
||||
|
||||
Args:
|
||||
n (Node): A node from the graph
|
||||
|
||||
Returns:
|
||||
bool: Whether the node is a ReLU-like node
|
||||
"""
|
||||
if n.op == 'call_function':
|
||||
return n.target in OUTPUT_SAVED_OPS
|
||||
elif n.op == 'call_module':
|
||||
return type(n.graph.owning_module.get_submodule(n.target)) in OUTPUT_SAVED_MOD
|
||||
return False
|
||||
|
||||
if not is_relu_like_node(n):
|
||||
return activation_size(n.meta["fwd_tmp"])
|
||||
return 0
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def calculate_fwd_out(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_out` (with sharding spec)
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
fwd_out (int): the result of `fwd_out`
|
||||
"""
|
||||
|
||||
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||
def intersect(a, b):
|
||||
return {k: a[k] for k in a if k in b}
|
||||
|
||||
fwd_in = dict()
|
||||
for u in n.users:
|
||||
fwd_in.update({x.uuid: x for x in u.meta["fwd_in"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')})
|
||||
fwd_out = {x.uuid: x for x in n.meta["fwd_out"] if isinstance(x, torch.Tensor) and hasattr(x, 'uuid')}
|
||||
return activation_size(intersect(fwd_in, fwd_out))
|
||||
|
||||
|
||||
def calculate_fwd_time(n: Node) -> float:
|
||||
"""A helper function to calculate `fwd_time` (with sharding spec)
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
Returns:
|
||||
fwd_time (float): the result of `fwd_time`
|
||||
"""
|
||||
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
|
||||
return n.meta["fwd_flop"]
|
||||
|
||||
|
||||
def calculate_bwd_time(n: Node) -> float:
|
||||
"""A helper function to calculate `bwd_time` (with sharding spec)
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
Returns:
|
||||
bwd_time (float): the result of `bwd_time`
|
||||
"""
|
||||
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
|
||||
return n.meta["bwd_flop"]
|
Reference in New Issue
Block a user