diff --git a/colossalai/auto_parallel/tensor_shard/utils/factory.py b/colossalai/auto_parallel/tensor_shard/utils/factory.py
index fd3ba3d41..05331e560 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/factory.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/factory.py
@@ -1,13 +1,16 @@
+import copy
 import operator
 import warnings
 from functools import reduce
 from typing import Dict, List, Optional, Union

 import torch
+from torch.fx.node import Node
+from torch.utils._pytree import tree_map
+
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx.node import Node

 from ..constants import INFINITY_COST

@@ -18,7 +21,7 @@ def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: Devic
                            dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
     """
     Generate the sharding spec of the tensor based on the given dim_partition_dict.
-    
+
     Args:
         input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
@@ -59,7 +62,7 @@ def generate_resharding_costs(nodes: List[Node],
         nodes (List[Node]): a list of nodes
         sharding_spec_for_input(ShardingSpec): a list of ShardingSpec for the nodes.
         count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
-        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None. 
+        dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
     '''
     # The resharding_cost of weight is counted due to sharing weight cases.
     resharding_costs = {}
@@ -88,3 +91,116 @@ def generate_resharding_costs(nodes: List[Node],
                 resharding_cost = INFINITY_COST
             resharding_costs[input_node].append(resharding_cost)
     return resharding_costs
+
+
+def find_repeat_blocks(node_list: List[torch.fx.Node], root_module, common_length_threshold: int = 20):
+    '''
+    Find the largest repeat blocks in the graph whose length is no shorter than the given threshold.
+
+    Args:
+        node_list (List[torch.fx.Node]): the graph nodes to scan, in execution order.
+        root_module: the module that owns the traced graph, used to resolve call_module targets.
+        common_length_threshold (int): the threshold of the repeat block length.
+    '''
+
+    # graph = gm.graph
+
+    def _process_args(args):
+        new_args = []
+        for arg in args:
+            if hasattr(arg, '_meta_data'):
+                meta_data = arg._meta_data
+            else:
+                meta_data = arg
+
+            def _process_arg(data):
+                if isinstance(data, torch.Tensor):
+                    data = data.size()
+                elif isinstance(data, slice):
+                    data = (data.start, data.step, data.stop)
+                return data
+
+            new_meta_data = tree_map(_process_arg, meta_data)
+            new_args.append(new_meta_data)
+
+        return new_args
+
+    def _all_equal(check_list, check_fn):
+        base_value = check_list[-1]
+        for e in check_list:
+            if not check_fn(e, base_value):
+                return False
+        return True
+
+    def _check_node_list_equal(l1, l2):
+        if len(l1) != len(l2):
+            return False
+        for node1, node2 in zip(l1, l2):
+            if hash(node1.hash_key) != hash(node2.hash_key):
+                return False
+        return True
+
+    def _check_node_equal(node1, node2):
+        if hash(node1.hash_key) == hash(node2.hash_key):
+            return True
+        return False
+
+    for index, node in enumerate(node_list):
+        if node.op == 'call_module':
+            target = node.target
+            submod = root_module.get_submodule(target)
+            submod_type = type(submod)
+            target = submod_type
+        else:
+            target = node.target
+
+        new_args = _process_args(node.args)
+
+        if node.op != 'get_attr':
+            hash_key = (node.op, target, *new_args)
+        else:
+            hash_key = (node.op,)
+
+        setattr(node, 'hash_key', hash_key)
+
+    hash_value_to_node_dict = {}
+
+    for index, node in enumerate(node_list):
+        hash_value = hash(node.hash_key)
+        if hash_value not in hash_value_to_node_dict:
+            hash_value_to_node_dict[hash_value] = []
+        hash_value_to_node_dict[hash_value].append(index)
+
+    # node_list = list(graph.nodes)
+
+    node_list_start = 0
+    max_common_length = common_length_threshold
+    common_blocks_index = []
+    for index, node in enumerate(node_list):
+        # the comparison will be triggered if a common node appears
+        if len(hash_value_to_node_dict[hash(node.hash_key)]) >= 2:
+            start_index_list = hash_value_to_node_dict[hash(node.hash_key)]
+            check_block_list = [node_list[start:start + max_common_length] for start in start_index_list]
+
+            common_label = True
+            if not _all_equal(check_block_list, _check_node_list_equal):
+                common_label = False
+
+            if common_label:
+                common_blocks_index = copy.deepcopy(start_index_list)
+                max_step = len(node_list) - common_blocks_index[-1] - max_common_length - 1
+
+                for i in range(max_step):
+                    # add assertion to avoid out of index
+                    next_node_list = [node_list[index + max_common_length + i] for index in start_index_list]
+                    if not _all_equal(next_node_list, _check_node_equal):
+                        max_step = i
+                        break
+                max_common_length += max_step
+                node_list_start += max_common_length
+
+    # recover common subgraph from the index
+    common_blocks = []
+    for start in common_blocks_index:
+        common_blocks.append(node_list[start:start + max_common_length])
+
+    return common_blocks
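Editorial aside on the matching strategy implemented above (not part of the patch): each node is first reduced to a hashable signature of its op, resolved target and shape-level argument metadata; indices of identical signatures are bucketed, and a candidate window of length common_length_threshold is then grown greedily while every occurrence keeps matching. The following dependency-free sketch reproduces that idea on plain strings instead of fx nodes; toy_find_repeat_blocks and the example sequence are illustrative names only and do not exist in ColossalAI.

import copy
from typing import List


def toy_find_repeat_blocks(items: List[str], threshold: int = 3) -> List[List[str]]:
    # Bucket the indices of identical items, mirroring hash_value_to_node_dict above.
    buckets = {}
    for index, item in enumerate(items):
        buckets.setdefault(item, []).append(index)

    max_common_length = threshold
    common_blocks_index = []
    for item in items:
        starts = buckets[item]
        if len(starts) < 2:
            continue
        # Candidate windows of the current length, one per occurrence of this item.
        windows = [items[s:s + max_common_length] for s in starts]
        if any(window != windows[-1] for window in windows):
            continue
        # Every occurrence matches: remember the starting indices and greedily
        # extend the window until the next element differs between occurrences
        # or the last occurrence runs off the end of the list.
        common_blocks_index = copy.deepcopy(starts)
        max_step = len(items) - starts[-1] - max_common_length - 1
        for i in range(max_step):
            next_items = [items[s + max_common_length + i] for s in starts]
            if any(candidate != next_items[-1] for candidate in next_items):
                max_step = i
                break
        max_common_length += max_step
        # For simplicity this sketch stops at the first family of repeats.
        break

    return [items[s:s + max_common_length] for s in common_blocks_index]


if __name__ == '__main__':
    # Two repeats of the block a-b-c-d followed by a non-repeating tail.
    sequence = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'd', 'x', 'y']
    print(toy_find_repeat_blocks(sequence, threshold=3))
    # prints: [['a', 'b', 'c', 'd'], ['a', 'b', 'c', 'd']]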
diff --git a/tests/test_auto_parallel/test_pass/__init__.py b/tests/test_auto_parallel/test_pass/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py
new file mode 100644
index 000000000..90301521f
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_find_repeat_block.py
@@ -0,0 +1,110 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.fx import GraphModule
+from transformers.pytorch_utils import Conv1D
+
+from colossalai.auto_parallel.tensor_shard.utils.factory import find_repeat_blocks
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+NUM_REPEAT_BLOCKS = 4
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 384
+
+
+class RepeatBlock(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size):
+        super().__init__()
+        self.c_fc = Conv1D(intermediate_size, hidden_size)
+        self.c_proj = Conv1D(hidden_size, intermediate_size)
+        self.act = torch.nn.ReLU()
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+
+        return hidden_states
+
+
+class RepeatModel(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size, num_layers):
+        super().__init__()
+        self.blocks = nn.ModuleList([RepeatBlock(intermediate_size, hidden_size) for i in range(num_layers)])
+
+    def forward(self, x):
+
+        for block in self.blocks:
+            x = block(x)
+
+        return x
+
+
+class NonRepeatBlock(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size, layer_index):
+        super().__init__()
+        intermediate_size //= (layer_index + 1)
+        self.c_fc = Conv1D(intermediate_size, hidden_size)
+        self.c_proj = Conv1D(hidden_size, intermediate_size)
+        self.act = torch.nn.ReLU()
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+
+        return hidden_states
+
+
+class NonRepeatModel(nn.Module):
+
+    def __init__(self, intermediate_size, hidden_size, num_layers):
+        super().__init__()
+        self.blocks = nn.ModuleList([NonRepeatBlock(intermediate_size, hidden_size, i) for i in range(num_layers)])
+
+    def forward(self, x):
+
+        for block in self.blocks:
+            x = block(x)
+
+        return x
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('model_cls', [RepeatModel, NonRepeatModel])
+def test_repeat_blocks(model_cls):
+
+    model = model_cls(4 * HIDDEN_DIM, HIDDEN_DIM, NUM_REPEAT_BLOCKS)
+
+    tracer = ColoTracer()
+    input_sample = {'x': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta')}
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+
+    node_list = list(graph.nodes)
+    root_module = graph.owning_module
+    common_blocks = find_repeat_blocks(node_list, root_module, common_length_threshold=10)
+
+    total_num_nodes = len(list(graph.nodes))
+    # remove the input placeholder node and the output node
+    num_repeat_nodes_per_block = (total_num_nodes - 2) // NUM_REPEAT_BLOCKS
+    for common_block in common_blocks:
+        print(common_block)
+    if model_cls == RepeatModel:
+        assert len(common_blocks) == NUM_REPEAT_BLOCKS
+        assert len(common_blocks[0]) == num_repeat_nodes_per_block
+    elif model_cls == NonRepeatModel:
+        assert len(common_blocks) == 0
+
+
+if __name__ == '__main__':
+    test_repeat_blocks()
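For context on how the result of find_repeat_blocks might be consumed downstream (an editorial sketch, not an API defined by this patch or by ColossalAI): every entry of common_blocks has the same length, and corresponding positions hold structurally identical nodes, so the blocks can be transposed position by position to transfer a decision made for one block to its repeats. The helper names group_corresponding_nodes, annotate_block_ids and the block_id attribute below are hypothetical.

from typing import List

import torch.fx


def group_corresponding_nodes(common_blocks: List[List[torch.fx.Node]]) -> List[List[torch.fx.Node]]:
    # Every block in common_blocks has the same length, and position i of each
    # block holds a structurally identical node, so transposing the result
    # yields groups of interchangeable nodes.
    return [list(group) for group in zip(*common_blocks)]


def annotate_block_ids(common_blocks: List[List[torch.fx.Node]]) -> None:
    # Hypothetical post-processing: tag every node with the index of the repeat
    # block it belongs to, so a later pass could solve one block and copy the
    # decision to the others. 'block_id' is an illustrative attribute name only.
    for block_id, block in enumerate(common_blocks):
        for node in block:
            setattr(node, 'block_id', block_id)

As a usage note, test_repeat_blocks is decorated with run_on_environment_flag(name='AUTO_PARALLEL'), so it is skipped unless that environment flag is enabled when pytest is invoked.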