diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index 228471870..1a816eb13 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -1,9 +1,7 @@
-from lib2to3.pytree import Base
 import operator
 from functools import reduce
 import torch
-from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
 from .operator_handler import OperatorHanlder
 
 
@@ -26,25 +24,6 @@ class ConvHandler(OperatorHanlder):
         assert self.input_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
 
-    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
-        '''
-        Compute the resharding costs with this specific strategy.
-
-        Note: The resharding_cost of weight is NOT counted.
-
-        Argument:
-            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
-                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
-                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
-                strategy.
-            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
-        '''
-        # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs[self.input_index] = []
-        for stategy in self.input_node.strategies_vector.strategies:
-            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
-            resharding_costs[self.input_index].append(resharding_cost)
-
     def _generate_compute_cost(self, bs, channel_in, channel_out):
         '''
         Compute the computation cost per device with this specific strategy.
diff --git a/colossalai/auto_parallel/solver/dot_handler.py b/colossalai/auto_parallel/solver/dot_handler.py
index a25466ea3..db99beb21 100644
--- a/colossalai/auto_parallel/solver/dot_handler.py
+++ b/colossalai/auto_parallel/solver/dot_handler.py
@@ -1,4 +1,8 @@
+import operator
+import torch
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
 from .operator_handler import OperatorHanlder
+from functools import reduce
 
 
 class DotHandler(OperatorHanlder):
@@ -6,7 +10,226 @@ class DotHandler(OperatorHanlder):
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def _generate_compute_cost(self, input_shape, weight_shape):
+        # TODO: consider bias addition
+        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
+        return compute_cost
 
-    # TODO: refactor the dot handler in my local branch to align with the latest main branch
+    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
+        # handle case SS = SR x RS
+        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
+
+        dim_partition_dict_for_input = {0: [mesh_dim_0]}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        # linear layer weight is transposed during init
+        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute computation cost
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # compute the communication cost
+        # no all-reduce required for this case
+        communication_cost = 0
+
+        # create and register strategy
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
+        # handle the case SR = SS x SR
+        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
+
+        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        # since weight of the linear layer is transposed
+        # the actual dim to be sharded is 1, along mesh_dim_1 which shards the contracted dim
+        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {0: [mesh_dim_0]}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_0]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # compute the communication cost of this strategy
+        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
+        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
+
+        dim_partition_dict_for_input = {1: [mesh_dim_0]}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        # the weight is transposed, so out_features is dim 0 and in_features is dim 1
+        dim_partition_dict_for_weight = {0: [mesh_dim_1], 1: [mesh_dim_0]}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {1: [mesh_dim_1]}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim_1]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # compute the communication cost of this strategy
+        # the contracted dim is sharded along mesh_dim_0, so the partial results are all-reduced over mesh_dim_0
+        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def recompute_split_both_contract(self, mesh_dim):
+        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
+
+        dim_partition_dict_for_input = {1: [mesh_dim]}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {1: [mesh_dim]}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        memory_cost = numel * size_per_elem_bytes
+
+        # compute the communication cost of this strategy
+        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def split_rhs_space_only(self, mesh_dim):
+        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
+
+        dim_partition_dict_for_input = {}
+        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
+
+        dim_partition_dict_for_weight = {0: [mesh_dim]}
+        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
+
+        dim_partition_dict_for_output = {1: [mesh_dim]}
+        sharding_spec_for_output = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+
+        # generate resharding cost for this strategy
+        resharding_costs = {}
+        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+
+        # compute the computation cost of this strategy
+        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
+
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel = self.output.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+        sharding_size = self.device_mesh.shape[mesh_dim]
+        memory_cost = numel * size_per_elem_bytes / sharding_size
+
+        # compute the communication cost of this strategy
+        # the contracted dim is not sharded, so no all-reduce is required for this case
+        communication_cost = 0
+        sharding_strategies = ShardingStrategy(name,
+                                               output_sharding_spec=sharding_spec_for_output,
+                                               compute_cost=compute_cost,
+                                               communication_cost=communication_cost,
+                                               memory_cost=memory_cost,
+                                               resharding_costs=resharding_costs,
+                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
+        self.strategies_vector.strategies.append(sharding_strategies)
+
+    def register_strategy_into_strategies_vector(self):
+        '''
+        Generate every possible sharding strategy for a linear (matmul) node, and record all strategies into the strategies_vector.
+        '''
+        # SS = SR x RS
+        self.split_lhs_space_rhs_space(0, 1)
+        self.split_lhs_space_rhs_space(1, 0)
+
+        # SR = SS x SR
+        self.split_lhs_space_both_contract(0, 1)
+        self.split_lhs_space_both_contract(1, 0)
+
+        # RS = RS x SS
+        self.split_rhs_space_both_contract(0, 1)
+        self.split_rhs_space_both_contract(1, 0)
+
+        # RR = RS x SR
+        self.recompute_split_both_contract(0)
+        self.recompute_split_both_contract(1)
+
+        # RS = RR x RS
+        self.split_rhs_space_only(0)
+        self.split_rhs_space_only(1)
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 24027e996..1331b306f 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -43,3 +43,23 @@ class OperatorHanlder(ABC):
                                      entire_shape=tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
         return sharding_spec
+
+    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+        '''
+        Compute the resharding costs with this specific strategy.
+
+        Note: The resharding_cost of weight is NOT counted.
+
+        Argument:
+            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
+                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
+                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
+                strategy.
+            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
+        '''
+        # The resharding_cost of weight is NOT counted due to sharing weight cases.
+        resharding_costs[self.input_index] = []
+        for strategy in self.input_node.strategies_vector.strategies:
+            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, sharding_spec_for_input)
+            resharding_costs[self.input_index].append(resharding_cost)
+        return resharding_costs
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index 025a13fc6..6c465d0c7 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -42,10 +42,13 @@ class StrategiesVector:
         strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
     '''
 
-    def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
+    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
         self.node = node
         self.in_nodes = in_nodes
         self.following_nodes = following_nodes
+
+        if strategies is None:
+            strategies = []
         self.strategies = strategies
 
     def check_merge(self):
diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py
new file mode 100644
index 000000000..8afed7dd6
--- /dev/null
+++ b/tests/test_auto_parallel/test_dot_handler.py
@@ -0,0 +1,113 @@
+import torch
+from torch.fx import GraphModule
+import torch.nn as nn
+import pytest
+
+from colossalai.fx.proxy import ColoProxy
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.auto_parallel.solver.dot_handler import DotHandler
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.device.device_mesh import DeviceMesh
+
+
+class LinearModel(nn.Module):
+
+    def __init__(self, in_features, out_features):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features)
+
+    def forward(self, x):
+        x = x * 2
+        x = self.linear(x)
+        return x
+
+
+def test_dot_handler():
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    entire_shape = torch.Size((4, 8))
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    model = LinearModel(8, 16)
+    input_sample = {'x': torch.rand(4, 8).to('meta')}
+    # graph():
+    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
+    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
+    #     %linear : [#users=1] = call_module[target=linear](args = (%mul,), kwargs = {})
+    #     return linear
+    graph = tracer.trace(root=model, meta_args=input_sample)
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    gm.recompile()
+    # [x, mul, linear, output]
+    nodes = [node for node in gm.graph.nodes]
+
+    strategies_for_input = []
+    sharding_option = (None, 0, 1)
+    for first_sharding_index in sharding_option:
+        for second_sharding_index in sharding_option:
+            if first_sharding_index is not None and second_sharding_index == first_sharding_index:
+                continue
+            if first_sharding_index is None:
+                first_dim_spec = _DimSpec([])
+            else:
+                first_dim_spec = _DimSpec([first_sharding_index])
+
+            if second_sharding_index is None:
+                second_dim_spec = _DimSpec([])
+            else:
+                second_dim_spec = _DimSpec([second_sharding_index])
+
+            sharding_sequence = [first_dim_spec, second_dim_spec]
+            sharding_spec = ShardingSpec(device_mesh=device_mesh,
+                                         entire_shape=entire_shape,
+                                         sharding_sequence=sharding_sequence)
+            strategies_for_input.append(sharding_spec)
+
+    # strategies_for_input = [[R, R], [R, S0], [R, S1], [S0, R], [S0, S1], [S1, R], [S1, S0]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
+    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
+
+    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
+        nodes[1],
+    ])
+    dot_handler = DotHandler(input_node=nodes[1],
+                             input_index=0,
+                             weight=dict(gm.named_modules())[nodes[2].name].weight,
+                             output_node=nodes[2],
+                             device_mesh=device_mesh,
+                             strategies_vector=strategies_vector,
+                             shape_consistency_manager=shape_consistency_manager)
+    dot_handler.register_strategy_into_strategies_vector()
+
+    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1']
+    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
+
+
+if __name__ == '__main__':
+    test_dot_handler()
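
Reviewer note on the cost model: the sketch below is a minimal, self-contained illustration (not part of the patch) of the arithmetic the new DotHandler strategies use, shown for the 'S0R = S0S1 x S1R' case on the 2 x 2 mesh from the test. The function name, the 4-byte element size, and the ring all-reduce formula are assumptions made purely for illustration; the patch itself delegates the communication term to device_mesh.all_reduce_cost.

from functools import reduce
import operator


def linear_strategy_costs(input_shape, weight_shape, mesh_shape,
                          output_shard_mesh_dims, contracted_mesh_dim,
                          size_per_elem_bytes=4, bandwidth=1.0):
    """Return (compute_cost, memory_cost, communication_cost) for one linear sharding strategy."""
    # compute cost: 2 * M * K * N FLOPs, mirroring
    # reduce(operator.mul, input_shape) * weight_shape[0] * 2 in the patch
    # (the weight is stored transposed, so weight_shape[0] is out_features)
    compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2

    # memory cost: bytes of the output shard each device keeps (2-D input assumed)
    output_numel = input_shape[0] * weight_shape[0]
    sharding_size = reduce(operator.mul, [mesh_shape[d] for d in output_shard_mesh_dims], 1)
    memory_cost = output_numel * size_per_elem_bytes / sharding_size

    # communication cost: an all-reduce over the mesh dim that shards the contracted
    # (in_features) dimension; zero when the contraction is not sharded.
    # A simple ring all-reduce model stands in for device_mesh.all_reduce_cost here.
    if contracted_mesh_dim is None:
        communication_cost = 0.0
    else:
        n = mesh_shape[contracted_mesh_dim]
        communication_cost = 2 * (n - 1) / n * memory_cost / bandwidth
    return compute_cost, memory_cost, communication_cost


# 'S0R = S0S1 x S1R' on a 2 x 2 mesh: output sharded on mesh dim 0,
# contracted dimension sharded on mesh dim 1 -> one all-reduce over mesh dim 1
print(linear_strategy_costs(input_shape=(4, 8), weight_shape=(16, 8),
                            mesh_shape=(2, 2), output_shard_mesh_dims=(0,),
                            contracted_mesh_dim=1))

With these inputs the sketch prints a compute cost of 1024 (2 x 4 x 8 x 16, counting multiplies and adds separately), a 128-byte output shard per device, and a 128-byte all-reduce volume, which is the shape of the trade-off the solver compares across strategies.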