diff --git a/colossalai/nn/graph/__init__.py b/colossalai/nn/graph/__init__.py deleted file mode 100644 index 0cfecf8b4..000000000 --- a/colossalai/nn/graph/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .utils import register_colo_graph -from .graph_node import GraphContext, GraphGlobalEnv, GraphOpNode - -__all__ = ['register_colo_graph', 'GraphContext', 'GraphGlobalEnv', 'GraphOpNode'] \ No newline at end of file diff --git a/colossalai/nn/graph/graph_node.py b/colossalai/nn/graph/graph_node.py deleted file mode 100644 index 32653ad98..000000000 --- a/colossalai/nn/graph/graph_node.py +++ /dev/null @@ -1,96 +0,0 @@ -from colossalai.tensor import ColoTensor -from colossalai.context.singleton_meta import SingletonMeta - - -class GraphGlobalEnv(metaclass=SingletonMeta): - - def __init__(self) -> None: - self.graph_building = False - self.graph_node_list = [] - self.node_id = -1 - - def get_node_id(self): - self.node_id += 1 - return self.node_id - - def add_graph_node(self, node): - self.graph_node_list.append(node) - - -class GraphContext(): - """ - - Building the computing graph under the context - - >>> with GraphContext(): - >>> output = model(colo_input_tensor) - """ - graph_nodes = [] - - def __enter__(self): - GraphGlobalEnv().graph_building = True - GraphGlobalEnv().graph_node_list = [] - - def __exit__(self, *exc_info): - GraphGlobalEnv().graph_building = False - GraphGlobalEnv().node_id = -1 - self.graph_nodes = GraphGlobalEnv().graph_node_list - - -class GraphNode(object): - - def __init__(self) -> None: - self.prev_nodes = [] - self.post_nodes = [] - self.id = GraphGlobalEnv().get_node_id() - - def add_prev_node(self, node): - if GraphGlobalEnv().graph_building: - self.prev_nodes.append(node) - - def add_post_node(self, node): - if GraphGlobalEnv().graph_building: - self.post_nodes.append(node) - - def post_node_empty(self) -> bool: - return len(self.post_nodes) == 0 - - -class GraphOpNode(GraphNode): - - def __init__(self, op_type, param_list) -> None: - super().__init__() - self._op_type = op_type - self._param_list = param_list - GraphGlobalEnv().add_graph_node(self) - - def add_prev_tensor(self, colo_tensor: ColoTensor): - r""" - Link the current graph op node to previous graph op. - Op1 <- Activation (colo_tensor) Op2 - Op1 <- Op2 - """ - if GraphGlobalEnv().graph_building: - assert isinstance(colo_tensor, ColoTensor) - if colo_tensor._graph_node is None: - colo_tensor._graph_node = GraphNode() - prev_ops = colo_tensor._graph_node.prev_nodes - for op_node in prev_ops: - self.add_prev_node(op_node) - op_node.add_post_node(self) - - def add_post_tensor(self, colo_tensor: ColoTensor): - """ - Op <- Activation (colo_tensor) - """ - if GraphGlobalEnv().graph_building: - assert isinstance(colo_tensor, ColoTensor), f'type {type(colo_tensor)}' - if colo_tensor._graph_node is None: - colo_tensor._graph_node = GraphNode() - - colo_tensor._graph_node.add_prev_node(self) - - def print(self): - print( - f'GraphOpNode {self._op_type} {self.id}, post nodes {[node.id for node in self.post_nodes]}, prev node number {[node.id for node in self.prev_nodes]}' - ) diff --git a/colossalai/nn/graph/utils.py b/colossalai/nn/graph/utils.py deleted file mode 100644 index 1070319ca..000000000 --- a/colossalai/nn/graph/utils.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import torch -from colossalai.tensor import ColoTensor -from typing import Callable, List -from colossalai.nn._ops._utils import convert_to_colo_tensor - - -def register_colo_graph(input_pos: List[int], param_pos: List[int]) -> Callable: - """register_colo_graph - Register a Op (Layer) to ColoGraph. - Recoders the input args in types of ColoTensor to the Graph. - - Args: - func (Callable): a function implements the Op. - - Returns: - Callable: wrapper function. - """ - - def register_colo_graph_decorator(func): - from colossalai.nn.graph import GraphOpNode, GraphGlobalEnv - - @functools.wraps(func) - def wrapper(*args, **kwargs): - param_list = [] - input_list = [] - # TODO(jiaruifang) find the pg - for idx, arg in enumerate(args): - if isinstance(arg, torch.Tensor) and idx in input_pos: - input_list.append(convert_to_colo_tensor(arg)) - if isinstance(arg, torch.Tensor) and idx in param_pos: - param_list.append(convert_to_colo_tensor(arg)) - # building the computing graph, inputs -> op - if GraphGlobalEnv().graph_building: - cur_op_node = GraphOpNode('linear', param_list) - # TODO supports a list of ColoTensor as args - if len(input_list) > 0: - cur_op_node.add_prev_tensor(input_list[0]) - - outputs = func(*args, **kwargs) - - # building the computing graph, op -> output - if GraphGlobalEnv().graph_building: - # TODO supports a list of ColoTensor as args - if isinstance(outputs[0], ColoTensor): - cur_op_node.add_post_tensor(outputs[0]) - return outputs - - return wrapper - - return register_colo_graph_decorator