[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -7,7 +7,7 @@ from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
__all__ = ["LiveVariable", "LiveVariableVector", "LiveStage", "GraphAnalyser"]
@dataclass
@@ -15,6 +15,7 @@ class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
@@ -55,6 +56,7 @@ class LiveStage:
"""
LiveStage is a data structure to record the living variables at this current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
@@ -62,7 +64,6 @@ class LiveStage:
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@@ -105,18 +106,18 @@ class GraphAnalyser:
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
if node.op == "call_function":
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
if node.kwargs.get("inplace", False):
is_inplace = True
elif node.op == 'call_module':
elif node.op == "call_module":
# to check if this is an inplace op such as torch.nn.Relu(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
if getattr(module, "inplace", False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
getattr(node, "_meta_data", None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
@@ -138,10 +139,12 @@ class GraphAnalyser:
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
stage = LiveStage(
name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy(),
)
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):