[fx] refactor memory utils and extend shard utils. (#1754)

* [fx] change memory.py to memory_utils.py.

* [fx] add shard utils.

* [fx] fix import.

* [fx] check code style.

* [fx] add comment.

* [autoparallel] first move.

* [fx] add time computations.
Authored by Super Daniel on 2022-10-26 14:24:41 +08:00, committed by GitHub
parent 63f250bbd4
commit 0584654c79
14 changed files with 177 additions and 122 deletions


@@ -0,0 +1,48 @@
# for PyTorch 1.11 compatibility uses
from typing import Dict, List, Tuple, Union

import torch
from torch.fx import GraphModule, Node

from ..._compatibility import compatibility

__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]


@compatibility(is_backward_compatible=True)
def calculate_fwd_in(n: Node) -> bool:
    """A helper function to calculate `fwd_in`

    Args:
        n (Node): a node from the graph

    Returns:
        save_fwd_in (bool): the result of `save_fwd_in`
    """
    return n.meta['save_fwd_in']


@compatibility(is_backward_compatible=True)
def calculate_fwd_tmp(n: Node) -> int:
    """A helper function to calculate `fwd_tmp`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_tmp (int): the result of `fwd_tmp`
    """
    return n.meta["fwd_mem_tmp"]


@compatibility(is_backward_compatible=True)
def calculate_fwd_out(n: Node) -> int:
    """A helper function to calculate `fwd_out`

    Args:
        n (Node): a node from the graph

    Returns:
        fwd_out (int): the result of `fwd_out`
    """
    return n.meta['fwd_mem_out']
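
Below is a minimal sketch, not part of this diff, of how these helpers might be combined to estimate the forward activation memory of a traced graph. It assumes a profiling pass has already populated each node's meta fields ('save_fwd_in', 'fwd_mem_tmp', 'fwd_mem_out'); the function name estimate_fwd_mem and the variable gm are hypothetical.

from torch.fx import GraphModule


def estimate_fwd_mem(gm: GraphModule) -> int:
    """Sum the temporary and output forward memory recorded on each node.

    Assumes node.meta has been filled in by a prior profiling pass; otherwise
    calculate_fwd_tmp / calculate_fwd_out will raise a KeyError.
    """
    total = 0
    for node in gm.graph.nodes:
        total += calculate_fwd_tmp(node) + calculate_fwd_out(node)
    return total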