mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-11 05:49:55 +00:00
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py. * [fx] add shard utils. * [fx] fix import. * [fx] check code style. * [fx] add comment. * [autoparallel] first move. * [fx] add time computations.
This commit is contained in:
48
colossalai/fx/profiler/experimental/shard_utils.py
Normal file
48
colossalai/fx/profiler/experimental/shard_utils.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# for PyTorch 1.11 compatibility uses
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.fx import GraphModule, Node
|
||||
|
||||
from ..._compatibility import compatibility
|
||||
|
||||
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_in(n: Node) -> bool:
|
||||
"""A helper function to calculate `fwd_in`
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
save_fwd_in (bool): the result of `save_fwd_in`
|
||||
"""
|
||||
return n.meta['save_fwd_in']
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_tmp(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_tmp`
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
fwd_tmp (int): the result of `fwd_tmp`
|
||||
"""
|
||||
return n.meta["fwd_mem_tmp"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def calculate_fwd_out(n: Node) -> int:
|
||||
"""A helper function to calculate `fwd_out`
|
||||
|
||||
Args:
|
||||
n (Node): a node from the graph
|
||||
|
||||
Returns:
|
||||
fwd_out (int): the result of `fwd_out`
|
||||
"""
|
||||
return n.meta['fwd_mem_out']
|
Reference in New Issue
Block a user