diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py
new file mode 100644
index 000000000..0eb7f32f4
--- /dev/null
+++ b/colossalai/fx/passes/meta_info_prop.py
@@ -0,0 +1,101 @@
+import torch
+import torch.fx
+from torch.fx.node import Node, map_aggregate
+from typing import Any, Tuple, NamedTuple, Optional, Dict
+from functools import reduce
+from torch.fx._compatibility import compatibility
+
+
+@compatibility(is_backward_compatible=True)
+class TensorMetadata(NamedTuple):
+    # TensorMetadata is a structure containing pertinent information
+    # about a tensor within a PyTorch program.
+
+    shape: torch.Size
+    dtype: torch.dtype
+    requires_grad: bool
+    stride: Tuple[int, ...]
+    numel: int
+    # TODO: we can add a list of sharding specs here and record the sharding
+    # behaviour by appending sharding specs to the list.
+
+
+def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
+    """
+    Extract a TensorMetadata NamedTuple describing `result`.
+    """
+    shape = result.shape
+    dtype = result.dtype
+    requires_grad = result.requires_grad
+    stride = result.stride()
+    numel = result.numel()
+
+    return TensorMetadata(shape, dtype, requires_grad, stride, numel)
+
+
+@compatibility(is_backward_compatible=True)
+class MetaInfoProp(torch.fx.Interpreter):
+    """
+    Execute an FX graph node-by-node and record the shape and type of
+    each node's result into the corresponding node's metadata.
+
+    Usage:
+        BATCH_SIZE = 2
+        DIM_IN = 4
+        DIM_OUT = 16
+        model = torch.nn.Linear(DIM_IN, DIM_OUT)
+        input_sample = torch.rand(BATCH_SIZE, DIM_IN)
+        orig_output = model(input_sample)
+        gm = symbolic_trace(model)
+        MetaInfoProp(gm).run(input_sample)
+
+        for node in gm.graph.nodes:
+            print(node.name, node.meta['tensor_meta'].dtype,
+                  node.meta['tensor_meta'].shape, node.meta['tensor_meta'].numel)
+
+        # The output of the above code is:
+        # input_1 torch.float32 torch.Size([2, 4]) 8
+        # weight torch.float32 torch.Size([16, 4]) 64
+        # bias torch.float32 torch.Size([16]) 16
+        # linear torch.float32 torch.Size([2, 16]) 32
+        # output torch.float32 torch.Size([2, 16]) 32
+
+    Args:
+        module (GraphModule): The module to be executed
+    """
+
+    def run_node(self, n: Node) -> Any:
+        result = super().run_node(n)
+
+        found_tensor = False
+
+        def extract_tensor_meta(obj):
+            if isinstance(obj, torch.Tensor):
+                nonlocal found_tensor
+                found_tensor = True
+                return _extract_tensor_metadata(obj)
+            else:
+                return obj
+
+        meta = map_aggregate(result, extract_tensor_meta)
+        if found_tensor:
+            n.meta['tensor_meta'] = meta
+        else:
+            # The node produced no tensor, so record a placeholder entry
+            n.meta['tensor_meta'] = TensorMetadata(None, None, False, None, 0)
+
+        n.meta['type'] = type(result)
+        return result
+
+    def propagate(self, *args):
+        """
+        Run `module` via interpretation, recording the shape and type of
+        each node's result, and return the final output.
+
+        Args:
+            *args (Tensor): the sample input.
+
+        Returns:
+            Any: The value returned from executing the Module
+        """
+        return super().run(*args)
diff --git a/colossalai/fx/passes/utils.py b/colossalai/fx/passes/utils.py
new file mode 100644
index 000000000..fb8d029b7
--- /dev/null
+++ b/colossalai/fx/passes/utils.py
@@ -0,0 +1,27 @@
+import torch
+from typing import Dict, Set
+from torch.fx.node import Node, map_arg
+
+
+def get_comm_size(parent_partition, child_partition):
+    """Given two partitions (parent and child),
+    calculate the communication size between the two.
+ """ + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + parent_node_names = [n.name for n in parent_partition.graph.nodes] + for node in child_partition.graph.nodes: + input_nodes: Dict[Node, None] = {} + map_arg(node.args, lambda n: input_nodes.setdefault(n)) + map_arg(node.kwargs, lambda n: input_nodes.setdefault(n)) + for n in input_nodes: + if n.name in parent_node_names and n not in visited_nodes: + comm_size += n.meta['tensor_meta'].numel + visited_nodes.add(n) + return comm_size diff --git a/tests/test_fx/test_comm_size_compute.py b/tests/test_fx/test_comm_size_compute.py new file mode 100644 index 000000000..bc040bcca --- /dev/null +++ b/tests/test_fx/test_comm_size_compute.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn +from torch.fx import symbolic_trace +from colossalai.fx.passes.meta_info_prop import MetaInfoProp +from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass +from colossalai.fx.passes.utils import get_comm_size + +MODEL_DIM = 16 +BATCH_SIZE = 8 +PIPELINE_SIZE = 2 + + +class MLP(torch.nn.Module): + + def __init__(self, dim: int): + super().__init__() + self.linear1 = torch.nn.Linear(dim, dim) + self.linear2 = torch.nn.Linear(dim, dim) + self.linear3 = torch.nn.Linear(dim, dim) + self.linear4 = torch.nn.Linear(dim, dim) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + x = self.linear3(x) + x = self.linear4(x) + return x + + +def test_comm_size_compute(): + model = MLP(MODEL_DIM) + input_sample = torch.rand(BATCH_SIZE, MODEL_DIM) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + annotated_model = uniform_split_pass(gm, PIPELINE_SIZE) + split_model, split_submodules = split_with_split_nodes_pass(annotated_model) + submodule_list = list(split_model.children()) + comm_size = get_comm_size(submodule_list[0], submodule_list[1]) + # the shape of tensor send from partition 0 to partition 1 is (8, 16) + assert comm_size == 128 + + +if __name__ == '__main__': + test_comm_size_compute() diff --git a/tests/test_fx/test_meta_info_prop.py b/tests/test_fx/test_meta_info_prop.py new file mode 100644 index 000000000..84cef23b0 --- /dev/null +++ b/tests/test_fx/test_meta_info_prop.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn +import colossalai +import colossalai.nn as col_nn +from torch.fx import symbolic_trace +from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata + +BATCH_SIZE = 2 +DIM_IN = 4 +DIM_OUT = 16 + + +def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): + assert meta_info_spec.shape == orig_tensor.shape + assert meta_info_spec.dtype == orig_tensor.dtype + assert meta_info_spec.requires_grad == orig_tensor.requires_grad + assert meta_info_spec.stride == orig_tensor.stride() + assert meta_info_spec.numel == orig_tensor.numel() + + +def test_meta_info_prop(): + model = torch.nn.Linear(DIM_IN, DIM_OUT) + input_sample = torch.rand(BATCH_SIZE, DIM_IN) + orig_output = model(input_sample) + gm = symbolic_trace(model) + MetaInfoProp(gm).run(input_sample) + for node in gm.graph.nodes: + if node.op == 'placeholder': + meta_check(node.meta['tensor_meta'], input_sample) + if node.op == 'output': + 
+            meta_check(node.meta['tensor_meta'], orig_output)
+
+
+if __name__ == '__main__':
+    test_meta_info_prop()
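
A minimal sketch of how the recorded `tensor_meta` can be consumed downstream, assuming the files above are importable; the helper name `activation_bytes` is illustrative and not part of this patch:

    import torch
    from torch.fx import symbolic_trace
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata


    def activation_bytes(gm: torch.fx.GraphModule) -> dict:
        # Estimate the output size in bytes of every node from the
        # `tensor_meta` recorded by MetaInfoProp.
        sizes = {}
        for node in gm.graph.nodes:
            meta = node.meta['tensor_meta']
            # Skip nodes whose result is not a single tensor, e.g. the
            # placeholder TensorMetadata(None, None, False, None, 0).
            if not isinstance(meta, TensorMetadata) or meta.dtype is None:
                continue
            itemsize = torch.tensor([], dtype=meta.dtype).element_size()
            sizes[node.name] = meta.numel * itemsize
        return sizes


    model = torch.nn.Linear(4, 16)
    gm = symbolic_trace(model)
    MetaInfoProp(gm).run(torch.rand(2, 4))
    print(activation_bytes(gm))
    # With fp32 (4 bytes/element) this prints sizes such as
    # {'input_1': 32, 'weight': 256, 'bias': 64, 'linear': 128, 'output': 128}

Calling `MetaInfoProp(gm).propagate(input_sample)` would be equivalent to the `run` calls used here and in the tests, since `propagate` simply forwards to `Interpreter.run`.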
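Since `get_comm_size` sums `numel` over the boundary tensors, it returns an element count rather than a byte count. A hedged sketch of the conversion, where `comm_size_bytes` is a hypothetical helper and fp32 is assumed as the default dtype:

    import torch
    from colossalai.fx.passes.utils import get_comm_size


    def comm_size_bytes(parent_partition, child_partition, dtype=torch.float32):
        # Hypothetical convenience wrapper: get_comm_size counts elements,
        # so multiply by the per-element byte width of the given dtype.
        itemsize = torch.tensor([], dtype=dtype).element_size()
        return get_comm_size(parent_partition, child_partition) * itemsize

For the MLP test above, this would yield 128 * 4 = 512 bytes per forward pass across the pipeline boundary.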