diff --git a/colossalai/auto_parallel/solver/__init__.py b/colossalai/auto_parallel/solver/__init__.py
index e69de29bb..a27c1d065 100644
--- a/colossalai/auto_parallel/solver/__init__.py
+++ b/colossalai/auto_parallel/solver/__init__.py
@@ -0,0 +1,6 @@
+from .operator_handler import OperatorHandler
+from .dot_handler import DotHandler
+from .conv_handler import ConvHandler
+from .sharding_strategy import ShardingStrategy, StrategiesVector
+
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy']
diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index 1a816eb13..4e0d46104 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -1,17 +1,20 @@
 import operator
 from functools import reduce
 import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
 
 
-class ConvHandler(OperatorHanlder):
+class ConvHandler(OperatorHandler):
     """
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
         self._sanity_check()
 
     def _sanity_check(self):
@@ -42,7 +45,7 @@ class ConvHandler(OperatorHanlder):
         # 1D: (L) * N * Cout * Cin * kernel
         # 2D: (H * W) * N * Cout * Cin * kernel
         # 3D: (H * W * D) * N * Cout * Cin * kernel
-        output_size = self.output.shape[2:]
+        output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)
@@ -59,11 +62,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -73,7 +75,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -87,7 +89,7 @@ class ConvHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
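Editor's note on the notation used throughout these handlers (a sketch, not part of the patch): a `dim_partition_dict` maps a tensor dimension to the list of device-mesh dimensions it is sharded over, and dimensions absent from the dict stay replicated, which is what the `S`/`R` letters in strategy names such as `S0S1 = S0R x RS1` encode. Only the output's dict appears in the hunk above; the weight's dict here is an assumption for illustration.

```python
# Illustrative reading of 'S0S1 = S0R x RS1' for a conv input (N, C, H, W)
# and weight (Cout, Cin, kH, kW) on a 2D device mesh:
dim_partition_dict_for_input = {0: [0]}           # batch N sharded over mesh dim 0
dim_partition_dict_for_weight = {0: [1]}          # assumed: Cout sharded over mesh dim 1
dim_partition_dict_for_output = {0: [0], 1: [1]}  # N over mesh dim 0, Cout over mesh dim 1
# C, H, W, Cin, kH and kW are absent from the dicts, so they stay replicated ('R').
```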
@@ -99,11 +101,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -113,7 +114,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -127,7 +128,7 @@ class ConvHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
@@ -139,11 +140,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -153,7 +153,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -167,7 +167,7 @@ class ConvHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def split_weight_out_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
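The memory-cost hunks above all apply the same arithmetic: the output's element count times the element size, divided by the number of mesh devices the output is sharded over. A self-contained sanity check with made-up shapes:

```python
import torch

# Hypothetical conv output (N, Cout, H, W) and a 2x2 device mesh.
output = torch.empty(16, 64, 32, 32)
device_mesh_shape = (2, 2)
size_per_elem_bytes = torch.tensor([], dtype=output.dtype).element_size()

# 'S0R = S0S1 x S1R' shards the output along mesh dim 0 only.
sharding_size = device_mesh_shape[0]
memory_cost = output.numel() * size_per_elem_bytes / sharding_size
print(memory_cost)  # 16 * 64 * 32 * 32 * 4 / 2 = 2097152.0 bytes
```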
@@ -179,11 +179,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -193,7 +192,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -208,7 +207,7 @@ class ConvHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def non_split(self):
         name = f'RR = RR x RR'
@@ -220,11 +219,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -234,7 +232,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes
@@ -248,9 +246,9 @@ class ConvHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies
         into the strategies_vector.
@@ -315,3 +313,5 @@ class ConvHandler(OperatorHanlder):
 
         # RR= RR x RR
         self.non_split()
+
+        return self.strategies_vector
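With the conv handler complete, `register_strategy()` is the new entry point: it both fills and returns the handler's `StrategiesVector`. A hedged usage sketch, where `conv_node`, `device_mesh` and `shape_consistency_manager` stand in for the objects the tests below construct:

```python
# Sketch only; conv_node, device_mesh and shape_consistency_manager are
# assumed to exist (see the test setup further down).
from colossalai.auto_parallel.solver import ConvHandler, StrategiesVector

conv_handler = ConvHandler(node=conv_node,
                           device_mesh=device_mesh,
                           strategies_vector=StrategiesVector(conv_node),
                           shape_consistency_manager=shape_consistency_manager)
for strategy in conv_handler.register_strategy():  # StrategiesVector is now a list subclass
    print(strategy.name)                           # e.g. 'S0S1 = S0R x RS1'
```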
""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.input_data = self.predecessor_node[0]._meta_data + self.weight = self.module_named_parameters['weight'] + self.output_data = self.node._meta_data + def _generate_compute_cost(self, input_shape, weight_shape): # TODO: consider bias addition compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 @@ -27,18 +33,17 @@ class DotHandler(OperatorHanlder): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input) # generate resharding cost for this strategy - resharding_costs = {} - self._generate_resharding_costs(resharding_costs, sharding_spec_for_input) + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) # compute computation cost compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy dtype = self.input_data.dtype - numel = self.output.numel() + numel = self.output_data.numel() size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1] memory_cost = numel * size_per_elem_bytes / sharding_size @@ -55,7 +60,7 @@ class DotHandler(OperatorHanlder): memory_cost=memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.strategies.append(sharding_strategies) + self.strategies_vector.append(sharding_strategies) def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): # handle the case SR = SS x SR @@ -70,18 +75,17 @@ class DotHandler(OperatorHanlder): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight) dim_partition_dict_for_output = {0: [mesh_dim_0]} - sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output) + sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output) # generate resharding cost for this strategy - resharding_costs = {} - self._generate_resharding_costs(resharding_costs, sharding_spec_for_input) + resharding_costs = self._generate_resharding_costs([sharding_spec_for_input]) # compute the computation cost of this strategy compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) # compute the memory cost of this strategy dtype = self.input_data.dtype - numel = self.output.numel() + numel = self.output_data.numel() size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size() sharding_size = self.device_mesh.shape[mesh_dim_0] memory_cost = numel * size_per_elem_bytes / sharding_size @@ -95,7 +99,7 @@ class DotHandler(OperatorHanlder): memory_cost=memory_cost, resharding_costs=resharding_costs, input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) - self.strategies_vector.strategies.append(sharding_strategies) + self.strategies_vector.append(sharding_strategies) def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' @@ -107,18 +111,17 @@ class DotHandler(OperatorHanlder): sharding_spec_for_weight = self._generate_sharding_spec(self.weight, 
@@ -107,18 +111,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
 
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -132,7 +135,7 @@ class DotHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def recompute_split_both_contract(self, mesh_dim):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
@@ -144,18 +147,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
 
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes
@@ -168,7 +170,7 @@ class DotHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
     def split_rhs_space_only(self, mesh_dim):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
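`_generate_compute_cost` above counts one multiply and one add per (input element, output feature) pair, hence the factor of 2. A quick self-contained check with illustrative shapes:

```python
import operator
from functools import reduce

input_shape = (16, 128)    # hypothetical (batch, in_features)
weight_shape = (256, 128)  # hypothetical (out_features, in_features)

# Same formula as DotHandler._generate_compute_cost (bias is still a TODO there).
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
print(compute_cost)  # 16 * 128 * 256 * 2 = 1048576 FLOPs
```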
@@ -180,18 +182,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
 
         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
 
         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
 
         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
 
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim]
         memory_cost = numel * size_per_elem_bytes / sharding_size
@@ -205,9 +206,9 @@ class DotHandler(OperatorHanlder):
                                                memory_cost=memory_cost,
                                                resharding_costs=resharding_costs,
                                                input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)
 
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies
         into the strategies_vector.
@@ -233,3 +234,4 @@ class DotHandler(OperatorHanlder):
         # RS = RR x RS
         self.split_rhs_space_only(0)
         self.split_rhs_space_only(1)
+        return self.strategies_vector
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 1331b306f..26dcfd892 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -1,15 +1,18 @@
+import torch
+import torch.nn as nn
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
-import torch.nn as nn
+from typing import Dict, List
 from colossalai.device.device_mesh import DeviceMesh
-from .sharding_strategy import StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from .sharding_strategy import StrategiesVector
 
 
-class OperatorHanlder(ABC):
+
+class OperatorHandler(ABC):
     '''
-    The OperatorHanlder is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategy for an operator node.
 
     Argument:
         input_node(Node): the input node in node argument list.
@@ -21,30 +24,43 @@ class OperatorHanlder(ABC):
         shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
     '''
 
-    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
-                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                  shape_consistency_manager: ShapeConsistencyManager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
+        self.node = node
+        self.predecessor_node = list(node._input_nodes.keys())
+        self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shape_consistency_manager = shape_consistency_manager
 
+        # find the module and its parameters associated with this node
+        # this can be used to compute the compute/communication/sharding cost
+        if self.node.op == 'call_module':
+            module = node.graph.owning_module.get_submodule(node.target)
+            named_parameters = list(module.named_parameters(recurse=False))
+            # convert named parameters from list to dict
+            named_parameters = {k: v for k, v in named_parameters}
+        else:
+            module = None
+            named_parameters = None
+        self.module = module
+        self.module_named_parameters = named_parameters
+
     @abstractmethod
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         pass
 
-    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
+        """
+        Generate the sharding spec of the tensor based on the given dim_partition_dict
+        where the key is the tensor dimension and the value is the mesh dimension for sharding.
+        """
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
         return sharding_spec
 
-    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+    def _generate_resharding_costs(self, sharding_spec_for_input):
         '''
         Compute the resharding costs with this specific strategy.
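The `call_module` branch above leans on two torch.fx facts: a traced graph keeps a handle to its owning `GraphModule`, and `node.target` is the submodule's qualified name. A minimal standalone sketch of that lookup:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

gm = symbolic_trace(nn.Sequential(nn.Linear(8, 4)))
for node in gm.graph.nodes:
    if node.op == 'call_module':
        # recover the module behind the node and its direct (non-recursive) parameters
        module = node.graph.owning_module.get_submodule(node.target)
        named_parameters = dict(module.named_parameters(recurse=False))
        print(node.target, list(named_parameters.keys()))  # 0 ['weight', 'bias']
```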
@@ -58,8 +74,10 @@ class OperatorHanlder(ABC):
             sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs[self.input_index] = []
-        for stategy in self.input_node.strategies_vector.strategies:
-            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
-            resharding_costs[self.input_index].append(resharding_cost)
-        return resharding_cost
+        resharding_costs = {}
+        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+            resharding_costs[input_node] = []
+            for strategy in input_node.strategies_vector:
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
+                resharding_costs[input_node].append(resharding_cost)
+        return resharding_costs
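For clarity, the dict `_generate_resharding_costs` now returns is keyed by predecessor node, with one cost per strategy already registered on that predecessor (weight nodes included because of shared-weight cases). An illustration with stand-in strings for fx `Node`s and invented costs:

```python
# Stand-in strings instead of fx Nodes; the cost values are invented.
resharding_costs = {
    'input_node': [0.0, 12.5, 12.5],  # one resharding cost per input strategy
    'weight_node': [0.0],             # weights counted for shared-weight cases
}
strategy_index = 1
print(resharding_costs['input_node'][strategy_index])  # 12.5
```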
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 6c465d0c7..b6eb2e220 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -1,6 +1,9 @@
 from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List
+from torch.fx.node import Node
+
+__all__ = ['ShardingStrategy', 'StrategiesVector']
 
 
 @dataclass
@@ -30,26 +33,21 @@ class ShardingStrategy:
     input_shardings: ShardingSpec = None
 
 
-class StrategiesVector:
+class StrategiesVector(list):
     '''
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible strategies of the node.
 
     Argument:
-        node(Node): node to build corresponding strategies_vector.
-        in_nodes(List[Node]): input nodes in the argument list of the node.
-        following_nodes(List[Node]): the nodes take the target node as their argument.
-        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
+        node (Node): node for which the list of sharding strategies is generated.
     '''
 
-    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
+    def __init__(self, node: Node):
+        super().__init__()
         self.node = node
-        self.in_nodes = in_nodes
-        self.following_nodes = following_nodes
-
-        if strategies is None:
-            strategies = []
-        self.strategies = strategies
+        # fetch its input and output nodes
+        self.predecessor_nodes = list(node._input_nodes.keys())
+        self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
         pass
diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py
index 33a4a6a60..13ec9a16f 100644
--- a/tests/test_auto_parallel/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_conv_handler.py
@@ -47,7 +47,9 @@ def test_conv_handler():
     # [x, mul, conv, output]
     nodes = [node for node in gm.graph.nodes]
 
-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:
@@ -68,28 +70,19 @@ def test_conv_handler():
                 sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                              entire_shape=entire_shape,
                                              sharding_sequence=sharding_sequence)
-                strategies_for_input.append(sharding_spec)
-
-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[0],
-                                                   in_nodes=[nodes[1], 2],
-                                                   strategies=strategies_for_input)
+                strategies_vector_for_input.append(sharding_spec)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    conv_handler = ConvHandler(input_node=nodes[1],
-                               input_index=0,
-                               weight=dict(gm.named_modules())[nodes[2].name].weight,
-                               output_node=nodes[2],
+    # generate conv strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    conv_handler = ConvHandler(node=nodes[2],
                                device_mesh=device_mesh,
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
-    conv_handler.register_strategy_into_strategies_vector()
+    conv_handler.register_strategy()
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]
 
     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
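The test hunks reference `gm`, `device_mesh`, `entire_shape` and `shape_consistency_manager` without showing their construction; a hedged sketch of the fixture they appear to assume, with a 2x2 logical mesh and a 4D activation shape (the exact `DeviceMesh` constructor arguments are an assumption):

```python
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

physical_mesh_id = torch.arange(0, 4)   # 4 devices
mesh_shape = (2, 2)                     # 2x2 logical mesh -> mesh dims 0 and 1
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
entire_shape = torch.Size((4, 16, 64, 64))  # hypothetical (N, C, H, W)
```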
diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py
index 8afed7dd6..f85546b15 100644
--- a/tests/test_auto_parallel/test_dot_handler.py
+++ b/tests/test_auto_parallel/test_dot_handler.py
@@ -47,7 +47,9 @@ def test_dot_handler():
     # [x, mul, linear, output]
     nodes = [node for node in gm.graph.nodes]
 
-    strategies_for_input = []
+    # find the sharding strategies for the input node of the linear node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:
@@ -67,26 +69,19 @@ def test_dot_handler():
                 sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                              entire_shape=entire_shape,
                                              sharding_sequence=sharding_sequence)
-                strategies_for_input.append(sharding_spec)
-
-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
+                strategies_vector_for_input.append(sharding_spec)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    dot_handler = DotHandler(input_node=nodes[1],
-                             input_index=0,
-                             weight=dict(gm.named_modules())[nodes[2].name].weight,
-                             output_node=nodes[2],
+    # generate dot strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    dot_handler = DotHandler(node=nodes[2],
                              device_mesh=device_mesh,
                              strategies_vector=strategies_vector,
                              shape_consistency_manager=shape_consistency_manager)
-    dot_handler.register_strategy_into_strategies_vector()
+    strategies_vector = dot_handler.register_strategy()
    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in strategies_vector]
 
     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
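Finally, since both handlers now share the `OperatorHandler(node, ...)` constructor, a caller can dispatch on the traced submodule's type. The helper below is illustrative only, not part of the patch:

```python
import torch.nn as nn
from colossalai.auto_parallel.solver import ConvHandler, DotHandler, StrategiesVector

def build_handler(node, device_mesh, shape_consistency_manager):
    # hypothetical helper: pick a handler class from the node's submodule type
    submodule = node.graph.owning_module.get_submodule(node.target)
    if isinstance(submodule, nn.Conv2d):
        handler_cls = ConvHandler
    elif isinstance(submodule, nn.Linear):
        handler_cls = DotHandler
    else:
        raise NotImplementedError(f'no handler for {type(submodule)}')
    return handler_cls(node=node,
                       device_mesh=device_mesh,
                       strategies_vector=StrategiesVector(node),
                       shape_consistency_manager=shape_consistency_manager)
```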