diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py index a27c1d065..c20ed18ca 100644 --- a/colossalai/auto_parallel/solver/__init__.py +++ b/colossalai/auto_parallel/solver/__init__.py @@ -2,5 +2,6 @@ from .operator_handler import OperatorHandler from .dot_handler import DotHandler from .conv_handler import ConvHandler from .sharding_strategy import ShardingStrategy, StrategiesVector +from .graph_analysis import GraphAnalyser -__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy'] +__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser'] diff --git a/colossalai/auto_parallel/solver/graph_analysis.py b/colossalai/auto_parallel/solver/graph_analysis.py new file mode 100644 index 000000000..53469c246 --- /dev/null +++ b/colossalai/auto_parallel/solver/graph_analysis.py @@ -0,0 +1,174 @@ +from dataclasses import dataclass +from torch.fx.node import Node +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from collections import OrderedDict as ODict +from typing import List, OrderedDict, Union, Any +from colossalai.fx.passes.utils import get_node_module + +__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser'] + + +@dataclass +class LiveVariable: + """ + LiveVariable is a data structure to store the meta information of a variable for liveness analysis. + """ + name: str + meta: Union[Any, List[Any]] + is_inplace: bool + + +class LiveVariableVector(list): + """ + LiveVariableVector is a data structure to store the list of LiveVariable objects. + """ + + def exists(self, name) -> bool: + """ + Check if a variable has already existed in the current list by name. + """ + for var in self: + if name == var.name: + return True + return False + + def get(self, name) -> LiveVariable: + for var in self: + if name == var.name: + return var + raise KeyError(f"Variable {name} is not found") + + def copy(self) -> "LiveVariableVector": + """ + Create a copy of this vector + """ + vector = LiveVariableVector() + for var in self: + vector.append(var) + return vector + + +@dataclass +class LiveStage: + """ + LiveStage is a data structure to record the living variables at this current node. + """ + name: str + node: Node + all_live_vars: LiveVariableVector + unique_live_vars: LiveVariableVector + + +class GraphAnalyser: + + def __init__(self, gm: GraphModule): + self._gm = gm + self._graph = gm.graph + + @property + def gm(self) -> GraphModule: + """ + Return the GraphModule object associated with this analyser. + """ + return self._gm + + @property + def graph(self) -> Graph: + """ + Return the Graph object associated with this analyser. + """ + return self._graph + + def liveness_analysis(self) -> OrderedDict[int, LiveStage]: + """ + Analyse the graph to obtain the variable liveness information. This function returns + an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. + """ + compute_nodes = self.graph.nodes + liveness_dict = ODict() + + # checked: record all variables created since the first stage + # all: record the live variables only exist until the current stage. + # this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage. + # unique: record the unique live variables only exist until the current stage. + # this is different from `all list` as some variables are duplicated. + checked_variables = LiveVariableVector() + all_live_variables = LiveVariableVector() + unique_live_vars = LiveVariableVector() + + def _add_param_or_buf(node, tensor_type): + module = get_node_module(node) + + if tensor_type == 'param': + iterator = module.named_parameters() + elif tensor_type == 'buffer': + iterator = module.named_buffers() + else: + raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}") + + for name, tensor in iterator: + tensor_name = f'{node.name}.{name}' + + if not checked_variables.exists(tensor_name): + live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False) + unique_live_vars.append(live_tensor) + checked_variables.append(live_tensor) + all_live_variables.append(live_tensor) + + for idx, node in enumerate(compute_nodes): + ############################# + # find new living variables # + ############################# + # detect whether the current op is an in-place op + # if it is an in-place op, we would deem it as a duplciate var + is_inplace = False + if node.op == 'call_function': + # check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True) + if node.kwargs.get('inplace', False): + is_inplace = True + elif node.op == 'call_module': + # to check if this is an inplace op such as torch.nn.Relu(inplace=True) + module = get_node_module(node) + if getattr(module, 'inplace', False): + is_inplace = True + + # add the output var + meta = getattr(node, '_meta_data', None) + live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace) + if not is_inplace: + unique_live_vars.append(live_var) + checked_variables.append(live_var) + all_live_variables.append(live_var) + + # add the model parameters + if node.op == 'call_module': + _add_param_or_buf(node, tensor_type='param') + _add_param_or_buf(node, tensor_type='buffer') + + # add this output variable to the checked list + checked_variables.append(live_var) + + # check if any input is not checked yet + for arg in node.args: + arg_name = str(arg) + if not checked_variables.exists(arg_name): + meta = getattr(node, '_meta_data', None) + live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False) + all_live_variables.append(live_var_from_arg) + checked_variables.append(live_var_from_arg) + unique_live_vars.append(live_var_from_arg) + + # TODO: add the logic to remove live variables + # this should be completed if we are able to trace the backward compute graph + + # add this stage to liveness dict + stage = LiveStage(name=node.name, + node=node, + all_live_vars=all_live_variables.copy(), + unique_live_vars=unique_live_vars.copy()) + liveness_dict[idx] = stage + return liveness_dict + + def get_alias_set(self): + pass diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py index d3e38c190..842c9d52e 100644 --- a/colossalai/fx/passes/utils.py +++ b/colossalai/fx/passes/utils.py @@ -160,3 +160,20 @@ def assign_bfs_level_to_nodes(graph: Graph): new_process_list.extend(get_all_consumers(graph, node)) nodes_to_process = new_process_list current_level += 1 + + +def get_node_module(node) -> torch.nn.Module: + """ + Find the module associated with the given node. + + Args: + node (torch.fx.Node): a torch.fx.Node object in the fx computation graph + + Returns: + torch.nn.Module: the module associated with the given node + """ + + assert node.graph.owning_module is not None, 'Cannot find the owning_module for node.graph, please make sure the graph is associated with a GraphModule object' + assert node.op == 'call_module', f'Expected node.op to be call_module, but found {node.op}' + module = node.graph.owning_module.get_submodule(node.target) + return module diff --git a/tests/test_auto_parallel/test_liveness_analysis.py b/tests/test_auto_parallel/test_liveness_analysis.py new file mode 100644 index 000000000..36039382f --- /dev/null +++ b/tests/test_auto_parallel/test_liveness_analysis.py @@ -0,0 +1,54 @@ +import torch.nn as nn +import torch +from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser +from colossalai.fx import ColoTracer, ColoGraphModule + + +class LinearModel(nn.Module): + + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(4, 4) + self.relu = nn.ReLU(inplace=True) + self.linear2 = nn.Linear(4, 4) + + def forward(self, x1, x2): + x1 = x1 * 2 + x1 = self.linear1(x1) + x1 = self.relu(x1) + x1 = self.linear2(x1) + out = x1 + x2 + return out + + +def test_liveness_analysis(): + model = LinearModel() + tracer = ColoTracer() + graph = tracer.trace(model, + meta_args={ + 'x1': torch.rand(4, 4, device='meta'), + 'x2': torch.rand(4, 4, device='meta') + }) + gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) + + graph_analyser = GraphAnalyser(gm) + liveness_dict = graph_analyser.liveness_analysis() + stage_count = len(liveness_dict) + + # 8 stages including input and output + assert stage_count == 8 + + # a variable named `relu` must exist + # and this live var must have inplace = True + assert liveness_dict[5].all_live_vars.exists('relu') + relu_var = liveness_dict[5].all_live_vars.get('relu') + assert relu_var.is_inplace + + # the unique vars must be fewer than the all vars since in-place ops exist + all_live_vars = liveness_dict[7].all_live_vars + unique_live_vars = liveness_dict[7].unique_live_vars + assert len(unique_live_vars) + 1 == len(all_live_vars) + + +if __name__ == '__main__': + test_liveness_analysis()