mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 09:59:38 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -7,7 +7,7 @@ from torch.fx.node import Node
|
||||
|
||||
from colossalai.fx.passes.utils import get_node_module
|
||||
|
||||
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
|
||||
__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,6 +15,7 @@ class LiveVariable:
|
||||
"""
|
||||
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
|
||||
"""
|
||||
|
||||
name: str
|
||||
node: Node
|
||||
is_inplace: bool
|
||||
@@ -55,6 +56,7 @@ class LiveStage:
|
||||
"""
|
||||
LiveStage is a data structure to record the living variables at this current node.
|
||||
"""
|
||||
|
||||
name: str
|
||||
node: Node
|
||||
all_live_vars: LiveVariableVector
|
||||
@@ -62,7 +64,6 @@ class LiveStage:
|
||||
|
||||
|
||||
class GraphAnalyser:
|
||||
|
||||
def __init__(self, gm: GraphModule):
|
||||
self._gm = gm
|
||||
self._graph = gm.graph
|
||||
@@ -105,18 +106,18 @@ class GraphAnalyser:
|
||||
# detect whether the current op is an in-place op
|
||||
# if it is an in-place op, we would deem it as a duplicate var
|
||||
is_inplace = False
|
||||
if node.op == 'call_function':
|
||||
if node.op == "call_function":
|
||||
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
|
||||
if node.kwargs.get('inplace', False):
|
||||
if node.kwargs.get("inplace", False):
|
||||
is_inplace = True
|
||||
elif node.op == 'call_module':
|
||||
elif node.op == "call_module":
|
||||
# to check if this is an inplace op such as torch.nn.Relu(inplace=True)
|
||||
module = get_node_module(node)
|
||||
if getattr(module, 'inplace', False):
|
||||
if getattr(module, "inplace", False):
|
||||
is_inplace = True
|
||||
|
||||
# add the output var
|
||||
meta = getattr(node, '_meta_data', None)
|
||||
getattr(node, "_meta_data", None)
|
||||
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
|
||||
if not is_inplace:
|
||||
unique_live_vars.append(live_var)
|
||||
@@ -138,10 +139,12 @@ class GraphAnalyser:
|
||||
# this should be completed if we are able to trace the backward compute graph
|
||||
|
||||
# add this stage to liveness dict
|
||||
stage = LiveStage(name=node.name,
|
||||
node=node,
|
||||
all_live_vars=all_live_variables.copy(),
|
||||
unique_live_vars=unique_live_vars.copy())
|
||||
stage = LiveStage(
|
||||
name=node.name,
|
||||
node=node,
|
||||
all_live_vars=all_live_variables.copy(),
|
||||
unique_live_vars=unique_live_vars.copy(),
|
||||
)
|
||||
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
|
||||
replace = False
|
||||
for index, prev_stage in enumerate(liveness_list):
|
||||
|
Reference in New Issue
Block a user