mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[fx/meta/rpc] move _meta_registration.py to fx folder / register fx functions with compatibility checks / remove color debug (#1710)
* [fx] move meta registration * [fx] fix tests. * [fx] fix test. * [fx] fix. * [meta] refactor meta registration.py. * [fx] add compatibility descriptions. * [fx] polish import. * [fx] add a decorator. * [fx] fix tests. * [fx] remove print. * [fx] edit raise error. * [fx] edit raise error. * [fx] add type hint. * [fx] fix import in experimental. * [rpc] remove color debug. * [meta] fix naming.
This commit is contained in:
@@ -2,7 +2,10 @@ from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Dict, List
|
||||
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from .._compatibility import compatibility
|
||||
from .memory import activation_size, is_inplace
|
||||
|
||||
|
||||
@@ -12,6 +15,7 @@ class Phase(Enum):
|
||||
PLACEHOLDER = 2
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
@dataclass
|
||||
class GraphInfo:
|
||||
"""
|
||||
@@ -69,6 +73,7 @@ def is_phase(n: Node, phase: Phase) -> bool:
|
||||
return n.meta['phase'] == phase
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def autograd_graph_analysis(graph: Graph) -> GraphInfo:
|
||||
"""Analyze the autograd node dependencies and find out the memory usage.
|
||||
Basically the input graph should have all nodes marked for keyword `phase`.
|
||||
|
Reference in New Issue
Block a user