mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 20:54:55 +00:00
[autoparallel] added liveness analysis (#1516)
* [autoparallel] added liveness analysis * remove memory cost
This commit is contained in:
parent
9a9ef65313
commit
a0436a62ee
@ -2,5 +2,6 @@ from .operator_handler import OperatorHandler
|
||||
from .dot_handler import DotHandler
|
||||
from .conv_handler import ConvHandler
|
||||
from .sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .graph_analysis import GraphAnalyser
|
||||
|
||||
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy']
|
||||
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser']
|
||||
|
174
colossalai/auto_parallel/solver/graph_analysis.py
Normal file
174
colossalai/auto_parallel/solver/graph_analysis.py
Normal file
@ -0,0 +1,174 @@
|
||||
from dataclasses import dataclass
|
||||
from torch.fx.node import Node
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from collections import OrderedDict as ODict
|
||||
from typing import List, OrderedDict, Union, Any
|
||||
from colossalai.fx.passes.utils import get_node_module
|
||||
|
||||
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
|
||||
|
||||
|
||||
@dataclass
class LiveVariable:
    """
    Record describing a single variable tracked by the liveness analysis.
    """
    # identifier of the variable; the vector helpers look variables up by this name
    name: str
    # meta information attached to the variable (a single object or a list of them)
    meta: Union[Any, List[Any]]
    # True when the variable was produced by an in-place operation
    is_inplace: bool
|
||||
|
||||
|
||||
class LiveVariableVector(list):
    """
    A list of LiveVariable objects with helpers for looking entries up by name.
    """

    def exists(self, name) -> bool:
        """
        Tell whether any variable stored in this vector carries the given name.
        """
        return any(name == item.name for item in self)

    def get(self, name) -> "LiveVariable":
        """
        Return the first stored variable whose name matches.

        Raises:
            KeyError: if no stored variable has the given name
        """
        for item in self:
            if name == item.name:
                return item
        raise KeyError(f"Variable {name} is not found")

    def copy(self) -> "LiveVariableVector":
        """
        Build a new vector referencing the same variable objects (a shallow copy).
        """
        duplicate = LiveVariableVector()
        duplicate.extend(self)
        return duplicate
|
||||
|
||||
|
||||
@dataclass
class LiveStage:
    """
    Snapshot of the variables that are live at one node of the graph.
    """
    # name of the stage
    name: str
    # the graph node this stage corresponds to
    node: Node
    # every live variable at this stage, including those produced by in-place ops
    all_live_vars: LiveVariableVector
    # the live variables at this stage with duplicates left out
    unique_live_vars: LiveVariableVector
|
||||
|
||||
|
||||
class GraphAnalyser:
    """
    GraphAnalyser wraps a GraphModule and runs analyses over its torch.fx graph.

    Args:
        gm (GraphModule): the graph module to analyse; the module and its ``graph`` are both kept
    """

    def __init__(self, gm: GraphModule):
        self._gm = gm
        self._graph = gm.graph

    @property
    def gm(self) -> GraphModule:
        """
        Return the GraphModule object associated with this analyser.
        """
        return self._gm

    @property
    def graph(self) -> Graph:
        """
        Return the Graph object associated with this analyser.
        """
        return self._graph

    def liveness_analysis(self) -> OrderedDict[int, LiveStage]:
        """
        Analyse the graph to obtain the variable liveness information.

        The nodes of the graph are visited in order and each node is one compute stage. At every
        stage the node output, the parameters and buffers of the called module (for ``call_module``
        nodes) and any not-yet-seen input node are recorded as live variables. Variables are never
        removed, so each stage holds everything recorded up to and including that node.

        Returns:
            OrderedDict[int, LiveStage]: an ordered dictionary where the key is the compute stage ID
                (the index of the node in the graph) and the value is a LiveStage object.
        """
        compute_nodes = self.graph.nodes
        liveness_dict = ODict()

        # checked: record all variables created since the first stage
        # all: record the live variables which exist until the current stage.
        #       this can be different from the `checked` list as some variables may be destroyed prior to this stage.
        # unique: record the unique live variables which exist until the current stage.
        #       this is different from the `all` list as some variables are duplicated.
        checked_variables = LiveVariableVector()
        all_live_variables = LiveVariableVector()
        unique_live_vars = LiveVariableVector()

        def _add_live_var(live_var, is_unique):
            # register a variable in the bookkeeping lists; duplicated variables skip the unique list
            if is_unique:
                unique_live_vars.append(live_var)
            checked_variables.append(live_var)
            all_live_variables.append(live_var)

        def _add_param_or_buf(node, tensor_type):
            # register the parameters or buffers of the module called by `node` as live variables
            module = get_node_module(node)

            if tensor_type == 'param':
                iterator = module.named_parameters()
            elif tensor_type == 'buffer':
                iterator = module.named_buffers()
            else:
                raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}")

            for name, tensor in iterator:
                tensor_name = f'{node.name}.{name}'

                if not checked_variables.exists(tensor_name):
                    live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False)
                    _add_live_var(live_tensor, is_unique=True)

        for idx, node in enumerate(compute_nodes):
            #############################
            # find new living variables #
            #############################
            # detect whether the current op is an in-place op
            # if it is an in-place op, we would deem it as a duplicate var
            is_inplace = False
            if node.op == 'call_function':
                # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
                is_inplace = bool(node.kwargs.get('inplace', False))
            elif node.op == 'call_module':
                # check if this is an inplace op such as torch.nn.ReLU(inplace=True)
                is_inplace = bool(getattr(get_node_module(node), 'inplace', False))

            # add the output var; an in-place output is kept out of the unique list
            meta = getattr(node, '_meta_data', None)
            live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace)
            _add_live_var(live_var, is_unique=not is_inplace)

            # add the model parameters and buffers
            if node.op == 'call_module':
                _add_param_or_buf(node, tensor_type='param')
                _add_param_or_buf(node, tensor_type='buffer')

            # check if any input is not checked yet
            for arg in node.args:
                # only graph nodes are variables; constants such as the `2` in `x * 2`
                # and containers of nodes must not be registered under their string form
                if not isinstance(arg, Node):
                    continue

                if not checked_variables.exists(arg.name):
                    # the meta information belongs to the argument, not to the node consuming it
                    arg_meta = getattr(arg, '_meta_data', None)
                    live_var_from_arg = LiveVariable(name=arg.name, meta=arg_meta, is_inplace=False)
                    _add_live_var(live_var_from_arg, is_unique=True)

            # TODO: add the logic to remove live variables
            # this should be completed if we are able to trace the backward compute graph

            # add this stage to liveness dict; the vectors are copied so later stages do not alter it
            stage = LiveStage(name=node.name,
                              node=node,
                              all_live_vars=all_live_variables.copy(),
                              unique_live_vars=unique_live_vars.copy())
            liveness_dict[idx] = stage
        return liveness_dict

    def get_alias_set(self):
        """
        Placeholder which is not implemented yet; it currently does nothing and returns None.
        """
        pass
|
@ -160,3 +160,20 @@ def assign_bfs_level_to_nodes(graph: Graph):
|
||||
new_process_list.extend(get_all_consumers(graph, node))
|
||||
nodes_to_process = new_process_list
|
||||
current_level += 1
|
||||
|
||||
|
||||
def get_node_module(node) -> torch.nn.Module:
    """
    Look up the submodule invoked by a ``call_module`` node.

    Args:
        node (torch.fx.Node): a torch.fx.Node object in the fx computation graph

    Returns:
        torch.nn.Module: the submodule of the graph's owning module addressed by ``node.target``

    Raises:
        AssertionError: if the node's graph has no owning module or the node is not a ``call_module`` node
    """
    owner = node.graph.owning_module
    assert owner is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object'
    assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}'
    return owner.get_submodule(node.target)
|
||||
|
54
tests/test_auto_parallel/test_liveness_analysis.py
Normal file
54
tests/test_auto_parallel/test_liveness_analysis.py
Normal file
@ -0,0 +1,54 @@
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
|
||||
from colossalai.fx import ColoTracer, ColoGraphModule
|
||||
|
||||
|
||||
class LinearModel(nn.Module):
    """
    Toy model for the liveness test: two linear layers with an in-place ReLU between them.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(4, 4)
        # in-place activation, so the analysis has an in-place module op to detect
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(4, 4)

    def forward(self, x1, x2):
        """
        Scale ``x1`` by 2, run it through linear1 -> relu -> linear2, then add ``x2``.
        """
        hidden = self.linear1(x1 * 2)
        hidden = self.relu(hidden)
        return self.linear2(hidden) + x2
|
||||
|
||||
|
||||
def test_liveness_analysis():
    """
    Trace LinearModel with meta inputs and check the stages produced by the liveness analysis.
    """
    model = LinearModel()
    meta_args = {
        'x1': torch.rand(4, 4, device='meta'),
        'x2': torch.rand(4, 4, device='meta'),
    }
    graph = ColoTracer().trace(model, meta_args=meta_args)
    gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)

    liveness_dict = GraphAnalyser(gm).liveness_analysis()

    # 8 stages including input and output
    assert len(liveness_dict) == 8

    # a variable named `relu` must exist at stage 5
    # and this live var must be flagged as in-place
    stage_five_vars = liveness_dict[5].all_live_vars
    assert stage_five_vars.exists('relu')
    assert stage_five_vars.get('relu').is_inplace

    # the unique vars must be fewer than the all vars since an in-place op exists
    last_stage = liveness_dict[7]
    assert len(last_stage.unique_live_vars) + 1 == len(last_stage.all_live_vars)
|
||||
|
||||
|
||||
# Allow running this test file directly as a script.
if __name__ == '__main__':
    test_liveness_analysis()
|
Loading…
Reference in New Issue
Block a user