From 3345c6d3524507b011d3624002483a083afb5df0 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 30 Aug 2022 16:32:09 +0800 Subject: [PATCH] [autoparellel]add strategies constructor (#1505) * [autoparellel]add strategies constructor * remove duplicated strategies * polish code * adapt cost graph with StrategiesConstructor * polish --- colossalai/auto_parallel/solver/constants.py | 22 ++ .../auto_parallel/solver/conv_handler.py | 7 + colossalai/auto_parallel/solver/cost_graph.py | 49 +-- .../auto_parallel/solver/operator_handler.py | 8 +- .../auto_parallel/solver/sharding_strategy.py | 29 +- .../solver/strategies_constructor.py | 355 ++++++++++++++++++ tests/test_auto_parallel/test_cost_graph.py | 97 +++++ .../test_strategies_constructor.py | 98 +++++ 8 files changed, 633 insertions(+), 32 deletions(-) create mode 100644 colossalai/auto_parallel/solver/constants.py create mode 100644 colossalai/auto_parallel/solver/strategies_constructor.py create mode 100644 tests/test_auto_parallel/test_cost_graph.py create mode 100644 tests/test_auto_parallel/test_strategies_constructor.py diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py new file mode 100644 index 000000000..a65b2b173 --- /dev/null +++ b/colossalai/auto_parallel/solver/constants.py @@ -0,0 +1,22 @@ +import torch +import operator + +__all__ = [ + 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP', + 'LINEAR_FUNC_OP' +] + +ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] +ELEMENTWISE_FUNC_OP = [ + torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, + operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout +] +CONV_MODULE_OP = [ + torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d +] +CONV_FUNC_OP = [ + torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d +] +LINEAR_MODULE_OP = [torch.nn.Linear] +LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py index 4c8935809..a00fe1862 100644 --- a/colossalai/auto_parallel/solver/conv_handler.py +++ b/colossalai/auto_parallel/solver/conv_handler.py @@ -494,3 +494,10 @@ class ConvHandler(OperatorHandler): self.split_1d_parallel_on_in_channel(0, 1) return self.strategies_vector + + +CONV_STRATEGIES_LIST = [ + 'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', + 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', + 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R' +] diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py index a67ac1c3f..94691397d 100644 --- a/colossalai/auto_parallel/solver/cost_graph.py +++ b/colossalai/auto_parallel/solver/cost_graph.py @@ -39,11 +39,11 @@ class CostGraph: dst_node = strategies_vector.node for src_node in strategies_vector.predecessor_nodes: node_pair = (src_node, dst_node) - src_index = strategies_vector.predecessor_nodes.index(src_node) + # src_index = strategies_vector.predecessor_nodes.index(src_node) edge_cost = {} for i in range(len(strategies_vector)): - 
for j in range(len(src_node.stategy_vector)): - edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j] + for j in range(len(src_node.strategies_vector)): + edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j] self.edge_costs[node_pair] = edge_cost # add parents and children attribute to node setattr(dst_node, 'parents', strategies_vector.predecessor_nodes) @@ -83,33 +83,19 @@ class CostGraph: merge_map = {} for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): resharding_costs = strategy.resharding_costs - resharding_cost_for_src = resharding_costs[src_node_index] + resharding_cost_for_src = resharding_costs[src_node] lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src)) merge_map[dst_strate_index] = lowest_cost_index # extra_node_cost for dst node - extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])] + self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])] for dst_strate_index, strategy in enumerate(dst_node.strategies_vector): target_strate_index = merge_map[dst_strate_index] - extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][ + self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][ target_strate_index] - if src_node in extra_node_costs: - extra_node_costs[dst_node][dst_strate_index] += extra_node_costs[src_node][target_strate_index] - - # connect dst node and parents of src node - dst_node.parents.remove(src_node) - src_node.children.remove(dst_node) - node_pair_to_remove = [(src_node, dst_node)] - for parent_node in src_node.parents: - if parent_node not in dst_node.parents: - dst_node.parents.append(parent) - if dst_node not in parent_node.children: - parent_node.children.append(dst_node) - # remove src node from cost graph when src node has no consumer. - if len(src_node.children) == 0: - parent_node.children.remove(src_node) - node_pair = (parent_node, src_node) - self.edge_costs.pop(node_pair) + if src_node in self.extra_node_costs: + self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][ + target_strate_index] # add new node pair to cost graph for parent_node in src_node.parents: @@ -121,9 +107,24 @@ class CostGraph: for i in range(self.node_lens[dst_node]): for j in range(self.node_lens[parent_node]): src_strate_index = merge_map[i] - edge_cost[(i, j)] = self.edge_costs[old_node_pair][(j, src_strate_index)] + edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)] self.edge_costs[new_node_pair] = edge_cost + # connect dst node and parents of src node + dst_node.parents.remove(src_node) + src_node.children.remove(dst_node) + self.edge_costs.pop((src_node, dst_node)) + for parent_node in src_node.parents: + if parent_node not in dst_node.parents: + dst_node.parents.append(parent_node) + if dst_node not in parent_node.children: + parent_node.children.append(dst_node) + # remove src node from cost graph when src node has no consumer. 
+ if len(src_node.children) == 0: + parent_node.children.remove(src_node) + node_pair = (parent_node, src_node) + self.edge_costs.pop(node_pair) + def simplify_graph(self): if not self.simplify: return diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py index 85174d9f4..675b71982 100644 --- a/colossalai/auto_parallel/solver/operator_handler.py +++ b/colossalai/auto_parallel/solver/operator_handler.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn from abc import ABC, abstractmethod from torch.fx.node import Node -from typing import Dict +from typing import Dict, List from colossalai.device.device_mesh import DeviceMesh from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.sharding_spec import ShardingSpec @@ -56,7 +56,7 @@ class OperatorHandler(ABC): """ pass - def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: + def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: """ Generate the sharding spec of the tensor based on the given dim_partition_dict where the key is the tensor dimension and the value is the mesh dimension for sharding. @@ -84,7 +84,9 @@ class OperatorHandler(ABC): for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input): resharding_costs[input_node] = [] for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' _, _, resharding_cost = self.shape_consistency_manager.shape_consistency( - strategy.output_sharding_spec, input_spec) + input_sharding_spec, input_spec) resharding_costs[input_node].append(resharding_cost) return resharding_costs diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py index 870d7e8dd..9e30bb753 100644 --- a/colossalai/auto_parallel/solver/sharding_strategy.py +++ b/colossalai/auto_parallel/solver/sharding_strategy.py @@ -1,7 +1,8 @@ from dataclasses import dataclass from colossalai.tensor.sharding_spec import ShardingSpec -from typing import Dict, List +from typing import Dict, List, Union, Tuple from torch.fx.node import Node +from .constants import * __all__ = ['ShardingStrategy', 'StrategiesVector'] @@ -25,12 +26,15 @@ class ShardingStrategy: ''' name: str - output_sharding_spec: ShardingSpec + # TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor. + output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]] compute_cost: float = 0. communication_cost: float = 0. memory_cost: float = 0. - resharding_costs: Dict[int, List[float]] = None - input_shardings: ShardingSpec = None + resharding_costs: Dict[Node, List[float]] = None + # sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input. 
+ # Therefore, we could process them at the specific op(operator.getitem) + input_shardings: List[ShardingSpec] = None class StrategiesVector(list): @@ -46,8 +50,23 @@ class StrategiesVector(list): super().__init__() self.node = node # fetch its input and output nodes + # TODO: placeholder input nodes self.predecessor_nodes = list(node._input_nodes.keys()) self.successor_nodes = list(node.users.keys()) def check_merge(self): - pass + merge_label = False + if self.node.op == 'call_module': + target = self.node.target + root_module = self.node.graph.owning_module + submod = root_module.get_submodule(target) + submod_type = type(submod) + # merge elementwise module node into following nodes + if submod_type in ELEMENTWISE_MODULE_OP: + merge_label = True + + if self.node.op == 'call_function': + if self.node.target in ELEMENTWISE_FUNC_OP: + merge_label = True + + return merge_label diff --git a/colossalai/auto_parallel/solver/strategies_constructor.py b/colossalai/auto_parallel/solver/strategies_constructor.py new file mode 100644 index 000000000..eca20ef3b --- /dev/null +++ b/colossalai/auto_parallel/solver/strategies_constructor.py @@ -0,0 +1,355 @@ +from torch.fx import Graph, Node +from colossalai.tensor.sharding_spec import ShardingSpec +from .sharding_strategy import ShardingStrategy, StrategiesVector +from .conv_handler import ConvHandler +from .constants import * +from copy import deepcopy +import math +import torch +import operator +from typing import Dict, List + + +class StrategiesConstructor: + + def __init__(self, graph, device_mesh, shape_consistency_manager, solver_options): + self.graph = graph + self.root_module = self.graph.owning_module + self.nodes = list(graph.nodes) + self.device_mesh = device_mesh + self.leaf_strategies = [] + self.strategy_map = {} + self.shape_consistency_manager = shape_consistency_manager + self.solver_options = solver_options + + def _generate_sharding_spec(self, node: Node, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec: + """ + Generate the sharding spec of the tensor based on the given dim_partition_dict + where the key is the tensor dimension and the value is the mesh dimension for sharding. + """ + meta_tensor = node._meta_data + sharding_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=meta_tensor.shape, + dim_partition_dict=dim_partition_dict) + return sharding_spec + + def _generate_resharding_costs(self, input_nodes, target_sharding_specs): + ''' + Compute the resharding costs with this specific strategy. + + Argument: + sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node. + ''' + resharding_costs = {} + for input_node, target_sharding_spec in zip(input_nodes, target_sharding_specs): + resharding_costs[input_node] = [] + for strategy in input_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' + _, _, resharding_cost = self.shape_consistency_manager.shape_consistency( + input_sharding_spec, target_sharding_spec) + resharding_costs[input_node].append(resharding_cost) + return resharding_costs + + def remove_duplicated_strategy(self, strategies_vector): + ''' + In build_strategies_and_cost method, we may produce some duplicated strategies. + In this method, we will remove the duplicated strategies depending on the strategies name. 
+ ''' + name_checklist = [] + remove_list = [] + for strategy in strategies_vector: + if strategy.name not in name_checklist: + name_checklist.append(strategy.name) + else: + remove_list.append(strategy) + for strategy in remove_list: + strategies_vector.remove(strategy) + + def build_strategies_and_cost(self): + for node in self.nodes: + strategies_vector = StrategiesVector(node) + # placeholder node + if node.op == 'placeholder': + # For placeholder nodes, if solver_options['fast_mode'] is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the placeholder node. + # Otherwise, all the possible sharding spec for the placeholder node will be enumerated. + + if self.solver_options['fast_mode']: + # create sharding strategy for placeholder + name = 'Replica Placeholder' + dim_partition_dict = {} + output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_placeholder = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_placeholder) + + # get_attr node + if node.op == 'get_attr': + # Same as placeholder nodes, if solver_options['fast_mode'] is True, we just let them in + # fully replicate status, then strategies of following node will be treated equally due + # to replicate status has no resharding cost to other status. At the same time, the searching + # space is smaller than enumerating all the possible sharding spec for the get_attr node. + # Otherwise, all the possible sharding spec for the get_attr node will be enumerated. + if self.solver_options['fast_mode']: + # create sharding strategy for get_attr + name = 'Replica Attribute' + dim_partition_dict = {} + output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost) + strategies_vector.append(sharding_strategy_attribute) + + # call_module node + if node.op == 'call_module': + + target = node.target + submod = self.root_module.get_submodule(target) + submod_type = type(submod) + + # conv module + if submod_type in CONV_MODULE_OP: + # use ConvHandler to create sharding strategies for conv module node + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector, + self.shape_consistency_manager) + conv_handler.register_strategy() + + # linear module + elif submod_type in LINEAR_MODULE_OP: + # use DotHandler to create sharding strategies for linear module node + dot_handler = DotHandler(node, self.device_mesh, strategies_vector, self.shape_consistency_manager) + dot_handler.register_strategy() + + # element-wise module + elif submod_type in ELEMENTWISE_MODULE_OP: + # create sharding strategy for element-wise module + assert len(strategies_vector.predecessor_nodes + ) == 1, f'Temporally, we just support single input element-wise op.' + input_node = strategies_vector.predecessor_nodes[0] + # For element-wise module, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise module. 
+ sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' + if input_sharding_spec in sharding_spec_checklist: + continue + + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' + + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = node._meta_data.numel() + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_node] = [ + cost if cost == 0 else math.inf for cost in resharding_costs[input_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # other module + else: + raise RuntimeError(f'{submod_type} module is NOT supported now.') + + # call_function node + if node.op == 'call_function': + target = node.target + # conv function + if target in CONV_FUNC_OP: + # use ConvHandler to create sharding strategies for conv node + # TODO: the operator_handler does NOT support function node processing now. + conv_handler = ConvHandler(node, self.device_mesh, strategies_vector, + self.shape_consistency_manager) + conv_handler.register_strategy() + + # linear function + elif target in LINEAR_FUNC_OP: + # use DotHandler to create sharding strategies for linear node + # TODO: the operator_handler does NOT support function node processing now. + linear_handler = DotHandler(node, self.device_mesh, strategies_vector, + self.shape_consistency_manager) + linear_handler.register_strategy() + + # element-wise function + elif target in ELEMENTWISE_FUNC_OP: + # TODO: integrate element-wise func and module together + # create sharding strategy for element-wise function + assert len(strategies_vector.predecessor_nodes + ) == 1, f'Temporally, we just support single input element-wise op.' + input_node = strategies_vector.predecessor_nodes[0] + # For element-wise function, we keep the sharding spec of output node same as + # the input. Therefore, the different strategies of input node with same + # output sharding spec will generate same strategy for element-wise function. + sharding_spec_checklist = [] + for strategy in input_node.strategies_vector: + # It looks a little bit confusing, the input of the processing node + # is the output of the input_node. + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' 
+ if input_sharding_spec in sharding_spec_checklist: + continue + sharding_spec_checklist.append(input_sharding_spec) + dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict) + output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict) + name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}' + # TODO: use meta_info_prop to profile memory cost and compute cost + compute_cost = node._meta_data.numel() + memory_cost = 0 + + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_node] = [ + 0 if cost == 0 else math.inf for cost in resharding_costs[input_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # torch.var_mean + elif target == torch.var_mean: + dim = node.kwargs['dim'] + input_tensor_node = strategies_vector.predecessor_nodes[0] + for strategy in input_tensor_node.strategies_vector: + input_sharding_spec = strategy.output_sharding_spec + assert isinstance(input_sharding_spec, + ShardingSpec), f'The input node should NOT be a tuple of tensor.' + entire_shape_input = input_sharding_spec.entire_shape + dim_partition_dict_input = input_sharding_spec.dim_partition_dict + name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})' + if dim in dim_partition_dict_input: + # We need to make the action dimension in replicate status + dim_partition_dict_for_input = deepcopy(dim_partition_dict_input) + dim_partition_dict_for_input.pop(dim) + new_input_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_input, + dim_partition_dict=dim_partition_dict_for_input) + entire_shape_output = deepcopy(entire_shape_input) + entire_shape_output.pop(dim) + dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_input) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [new_input_sharding_spec]) + sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[new_input_sharding_spec]) + + else: + entire_shape_output = deepcopy(entire_shape_input) + entire_shape_output.pop(dim) + dim_partition_dict_for_output = deepcopy(dim_partition_dict_input) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partion_dict=dim_partition_dict_input) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. 
+ compute_cost = 0 + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec), + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_sharding_spec]) + + strategies_vector.append(sharding_strategy) + + # operator.getitem + elif target == operator.getitem: + index = node.args[1] + input_tensor_node = strategies_vector.predecessor_nodes[0] + for strategy in input_tensor_node.strategies_vector: + input_sharding_spec = input_tensor_node.output_sharding_spec[index] + assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.' + dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict) + entire_shape_output = deepcopy(input_sharding_spec.entire_shape) + output_sharding_spec = ShardingSpec(self.device_mesh, + entire_shape_output, + dim_partition_dict=dim_partition_dict_for_output) + # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec. + compute_cost = 0 + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + [input_sharding_spec]) + # to prevent the resharding happening, set their resharding cost to inf. + resharding_costs[input_tensor_node] = [ + cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node] + ] + sharding_strategy = ShardingStrategy(name, + output_sharding_spec, + compute_cost=compute_cost, + memory_cost=memory_cost, + resharding_costs=resharding_costs, + input_shardings=[input_tensor_node.output_sharding_spec]) + strategies_vector.append(sharding_strategy) + + # other function + else: + raise RuntimeError(f'{target} function is NOT supported now.') + + # output node + if node.op == 'output': + if self.solver_options['fast_mode']: + # create sharding strategy for output + name = 'Replica Output' + input_nodes = strategies_vector.predecessor_nodes + input_sharding_specs = [] + for input_node in input_nodes: + dim_partition_dict_for_input = {} + entire_shape = input_node._meta_data.shape + sharding_spec = ShardingSpec(self.device_mesh, + entire_shape, + dim_partition_dict=dim_partition_dict_for_input) + input_sharding_specs.append(sharding_spec) + + dim_partition_dict = {} + output_sharding_spec = input_sharding_specs + # TODO: use meta_info_prop to profile memory cost + memory_cost = 0 + resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, + input_sharding_specs) + sharding_strategy_attribute = ShardingStrategy(name, + output_sharding_spec, + memory_cost=memory_cost, + resharding_costs=resharding_costs) + strategies_vector.append(sharding_strategy_attribute) + + self.remove_duplicated_strategy(strategies_vector) + setattr(node, 'strategies_vector', strategies_vector) + self.leaf_strategies.append(strategies_vector) + self.strategy_map[node] = strategies_vector diff --git a/tests/test_auto_parallel/test_cost_graph.py b/tests/test_auto_parallel/test_cost_graph.py new file mode 100644 index 000000000..bb3e05087 --- /dev/null +++ b/tests/test_auto_parallel/test_cost_graph.py @@ -0,0 +1,97 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec 
+from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from colossalai.auto_parallel.solver.cost_graph import CostGraph +from copy import deepcopy + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3) + self.relu = nn.ReLU() + + def forward(self, x): + x = x * 2 + x = self.conv1(x) + x = x / 2 + x = self.relu(x) + return x + + +def test_cost_graph(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {}) + # %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {}) + # %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {}) + # return relu + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + solver_options = {'fast_mode': True} + strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + strategies_constructor.build_strategies_and_cost() + + # (x, mul): {(0, 0): 0} + # (mul, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0} + # (conv1, truediv): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): 0, (8, 0): inf, (9, 0): 0, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): inf, (5, 1): 0, (6, 1): 0, (7, 1): inf, (8, 1): 0, (9, 1): inf, (10, 1): 0, (11, 1): 0, (12, 1): 0, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): 0, (11, 2): 0, (12, 2): 0, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): 0, (11, 3): 0, (12, 3): 0, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): 0, (11, 4): 0, (12, 4): 0, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): 0, (11, 5): 0, (12, 5): 0, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, 
(4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): 0, (11, 7): 0, (12, 7): 0, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0, (9, 8): inf, (10, 8): 0, (11, 8): 0, (12, 8): 0, (13, 8): inf, (14, 8): 0} + # (truediv, relu): {(0, 0): 0, (1, 0): inf, (2, 0): 0, (3, 0): inf, (4, 0): inf, (5, 0): 0, (6, 0): 0, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): 0, (4, 1): 0, (5, 1): inf, (6, 1): 0, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): 0, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): 0, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): 0, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): 0, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): 0, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): 0, (5, 8): inf, (6, 8): 0, (7, 8): inf, (8, 8): 0} + # (relu, output): {(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002} + cost_graph = CostGraph(strategies_constructor.leaf_strategies) + + # construct all node pairs + all_node_pairs = [] + + for node in graph.nodes: + if node.op == 'output': + continue + all_node_pairs.append((node, node.next)) + + for node_pair in all_node_pairs: + assert node_pair in cost_graph.edge_costs + + # construct merged node pairs + merged_node_pairs = [] + node_list = list(graph.nodes) + + # add (x, conv) and (conv, output) into check node pairs + merged_node_pairs.append((node_list[0], node_list[2])) + merged_node_pairs.append((node_list[2], node_list[-1])) + # (x, conv1): {(0, 0): 0, (0, 1): 0, (0, 2): 0, (0, 3): 0, (0, 4): 0, (0, 5): 0, (0, 6): 0, (0, 7): 0, (0, 8): 0, (0, 9): 0, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 0, (0, 14): 0} + # (conv1, output): {(0, 0): inf, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): 0, (11, 0): 0, (12, 0): 0, (13, 0): inf, (14, 0): inf} + cost_graph.simplify_graph() + for node_pair in all_node_pairs: + if node_pair in merged_node_pairs: + assert node_pair in cost_graph.edge_costs + else: + assert node_pair not in cost_graph.edge_costs + + +if __name__ == '__main__': + test_cost_graph() diff --git a/tests/test_auto_parallel/test_strategies_constructor.py b/tests/test_auto_parallel/test_strategies_constructor.py new file mode 100644 index 000000000..41a7e8bd7 --- /dev/null +++ b/tests/test_auto_parallel/test_strategies_constructor.py @@ -0,0 +1,98 @@ +import torch +from torch.fx import GraphModule +import torch.nn as nn +import pytest + +from colossalai.fx.proxy import ColoProxy +from colossalai.fx.tracer.tracer import ColoTracer +from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec +from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST +from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector +from 
colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.device.device_mesh import DeviceMesh +from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor +from copy import deepcopy + + +class ConvModel(nn.Module): + + def __init__(self, c_in, c_out): + super().__init__() + self.conv = nn.Conv2d(c_in, c_out, kernel_size=3) + + def forward(self, x): + x = x * 2 + x = self.conv(x) + return x + + +def test_strategies_constructor(): + physical_mesh_id = torch.arange(0, 4) + mesh_shape = (2, 2) + # [[0, 1] + # [2, 3]] + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + entire_shape = torch.Size((4, 16, 64, 64)) + shape_consistency_manager = ShapeConsistencyManager() + + tracer = ColoTracer() + model = ConvModel(16, 32) + input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')} + # graph(): + # %x : torch.Tensor [#users=1] = placeholder[target=x] + # %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {}) + # %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {}) + # return conv + graph = tracer.trace(root=model, meta_args=input_sample) + gm = GraphModule(model, graph, model.__class__.__name__) + gm.recompile() + + solver_options = {'fast_mode': True} + strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options) + + assert strategies_constructor.leaf_strategies == [] + assert strategies_constructor.strategy_map == {} + strategies_constructor.build_strategies_and_cost() + + # check leaf_strategies + + # In fast mode, placeholder node only has replica strategy. + assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder' + + # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. + assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]' + + # Third node is conv. + conv_check_list = deepcopy(CONV_STRATEGIES_LIST) + for strategy in strategies_constructor.leaf_strategies[2]: + conv_check_list.remove(strategy.name) + assert len(conv_check_list) == 0 + + # In fast mode, output node only has replica strategy. + assert strategies_constructor.leaf_strategies[3][0].name == 'Replica Output' + + # check strategy_map + + nodes = [node for node in graph.nodes] + # In fast mode, placeholder node only has replica strategy. + x = nodes[0] + assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder' + + # Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec. + mul = nodes[1] + assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]' + + # Third node is conv. + conv = nodes[2] + conv_check_list = deepcopy(CONV_STRATEGIES_LIST) + for strategy in strategies_constructor.strategy_map[conv]: + conv_check_list.remove(strategy.name) + assert len(conv_check_list) == 0 + + # In fast mode, output node only has replica strategy. + output = nodes[3] + assert strategies_constructor.strategy_map[output][0].name == 'Replica Output' + + +if __name__ == '__main__': + test_strategies_constructor()
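
The two new tests exercise the pieces introduced by this patch end to end. As a condensed sketch (same imports and APIs as the tests above, with the toy ConvModel from test_cost_graph.py), the intended flow is: trace the model on meta tensors, enumerate per-node sharding strategies, then build and simplify the cost graph.

import torch
import torch.nn as nn
from torch.fx import GraphModule

from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph


class ConvModel(nn.Module):
    # same toy model as tests/test_auto_parallel/test_cost_graph.py
    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x * 2
        x = self.conv1(x)
        x = x / 2
        x = self.relu(x)
        return x


# a 2x2 logical device mesh over 4 physical devices: [[0, 1], [2, 3]]
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
shape_consistency_manager = ShapeConsistencyManager()

# trace the model symbolically on meta tensors
tracer = ColoTracer()
model = ConvModel(16, 32)
graph = tracer.trace(root=model, meta_args={'x': torch.rand(4, 16, 64, 64).to('meta')})
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()

# enumerate sharding strategies per node, then build and simplify the cost graph
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()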
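
A note on the indexing convention the patch settles on: ShardingStrategy.resharding_costs is now keyed by the producing Node (not by its index), and for a node pair (src, dst) the entry edge_costs[(src, dst)][(j, i)] holds the cost of resharding the output of src's j-th strategy into what dst's i-th strategy expects. The helper below is only an illustrative restatement of the loop that populates these costs in CostGraph, not code taken from the patch.

from typing import Dict, Tuple
from torch.fx.node import Node


def build_edge_cost(src_node: Node, dst_node: Node) -> Dict[Tuple[int, int], float]:
    """Illustrative restatement of the edge-cost construction in CostGraph.

    Key (j, i): j indexes a strategy of the source (producer) node,
    i indexes a strategy of the destination (consumer) node.
    """
    edge_cost = {}
    for i, dst_strategy in enumerate(dst_node.strategies_vector):
        # after this patch, resharding_costs is keyed by the producer Node itself
        costs_from_src = dst_strategy.resharding_costs[src_node]
        for j in range(len(src_node.strategies_vector)):
            edge_cost[(j, i)] = costs_from_src[j]
    return edge_cost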
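
Finally, remove_duplicated_strategy deduplicates strategies by name (for conv nodes the names come from CONV_STRATEGIES_LIST, e.g. 'S0S1 = S0R x RS1'), keeping the first occurrence of each name. A standalone sketch with the same behaviour:

def dedup_by_name(strategies_vector) -> None:
    # keep the first strategy for each name, preserving order; same effect as
    # StrategiesConstructor.remove_duplicated_strategy in this patch (sketch only)
    seen = set()
    kept = []
    for strategy in strategies_vector:
        if strategy.name not in seen:
            seen.add(strategy.name)
            kept.append(strategy)
    strategies_vector[:] = kept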