diff --git a/colossalai/auto_parallel/solver/conv_handler.py b/colossalai/auto_parallel/solver/conv_handler.py
index 6526e1018..3cbe43926 100644
--- a/colossalai/auto_parallel/solver/conv_handler.py
+++ b/colossalai/auto_parallel/solver/conv_handler.py
@@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
         # 3D: (H * W * D) * N * Cout * Cin * kernel
         output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
+        input_size = self.input_data.shape[2:]
+        input_size_product = reduce(operator.mul, input_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)
-        compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
+        backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
+        compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
         return compute_cost
 
+    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
+                              sharding_size_backward_weight):
+        '''
+        Compute the memory cost per device with this specific strategy.
+
+        Argument:
+            sharding_size_forward(int): The forward activation will be divided
+                into sharding_size_forward partitions.
+            sharding_size_backward_activation(int): The backward activation will
+                be divided into sharding_size_backward_activation partitions.
+            sharding_size_backward_weight(int): The backward weight will be divided
+                into sharding_size_backward_weight partitions.
+
+        Return:
+            memory_cost(Tuple[float]): Memory cost per device with this
+                specific strategy; the first element of this tuple is the
+                forward memory cost, and the second element of this tuple
+                is the backward memory cost.
+            memory_cost_forward(float): Memory cost of the forward activation per
+                device with this specific strategy.
+            memory_cost_backward_activation(float): Memory cost of the backward
+                activation per device with this specific strategy.
+        '''
+        # compute the memory cost of this strategy
+        dtype = self.input_data.dtype
+        numel_output = self.output_data.numel()
+        numel_input = self.input_data.numel()
+        numel_weight = self.weight.numel()
+        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
+
+        # forward memory_cost
+        memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
+
+        # backward memory_cost
+        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
+        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
+        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
+
+        # memory_cost pair
+        memory_cost = (memory_cost_forward, memory_cost_backward)
+
+        return memory_cost, memory_cost_forward, memory_cost_backward_activation
+
     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
 
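The new `_generate_memory_cost` helper just divides each tensor's byte size by the number of partitions the strategy produces for it. A minimal standalone sketch of the same arithmetic; the conv shapes and partition counts below are made-up illustration values, not taken from the patch:

```python
import torch

def memory_cost_sketch(output, input_, weight, sharding_size_forward,
                       sharding_size_backward_activation, sharding_size_backward_weight):
    # bytes per element, derived from the dtype exactly as in the patch
    size_per_elem_bytes = torch.tensor([], dtype=input_.dtype).element_size()
    memory_cost_forward = output.numel() * size_per_elem_bytes / sharding_size_forward
    memory_cost_backward = (input_.numel() * size_per_elem_bytes / sharding_size_backward_activation
                            + weight.numel() * size_per_elem_bytes / sharding_size_backward_weight)
    return memory_cost_forward, memory_cost_backward

# hypothetical conv on a 2x2 mesh: N=4, Cin=16, Cout=32, 3x3 kernel, 32x32 feature map
output = torch.empty(4, 32, 32, 32)    # N, Cout, H, W
input_ = torch.empty(4, 16, 32, 32)    # N, Cin, H, W
weight = torch.empty(32, 16, 3, 3)     # Cout, Cin, kH, kW
print(memory_cost_sketch(output, input_, weight, 4, 2, 2))  # bytes per device
```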
@@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
+
+        # This strategy does not need to do all_reduce during the forward phase
+        communication_cost_forward = 0
+        # compute the backward communication cost of this strategy
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # total communication cost
+        communication_cost = communication_cost_forward + communication_cost_backward
 
-        # This strategy do not need to do all_reduce operation
-        communication_cost = 0
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy does not need to do all_reduce in either the forward or the backward phase.
         communication_cost = 0
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
@@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
+        # compute the communication cost of this strategy during the forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
+        # This strategy does not need to do all_reduce during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        # compute the communication cost of this strategy during the forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # compute the communication cost of this strategy during the backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute the communication cost of this strategy
-        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
+        # compute the communication cost of this strategy during the forward phase
+        communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
+        # This strategy does NOT need all_reduce during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
-
-        # This strategy do not need to do all_reduce operation
-        communication_cost = 0
+        sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
+        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
+        # This strategy does not need to do all_reduce during the forward phase
+        communication_cost_forward = 0
+        # compute the communication cost of this strategy during the backward phase
+        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
+        communication_cost = communication_cost_forward + communication_cost_backward
 
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
                                                compute_cost=compute_cost,
@@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = 1
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy does not need to do all_reduce in either the forward or the backward phase
         communication_cost = 0
 
         sharding_strategies = ShardingStrategy(name,
@@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost = numel * size_per_elem_bytes / sharding_size
+        sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
+            mesh_dim_1]
+        sharding_size_backward_weight = 1
+        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                       sharding_size_backward_weight)
 
-        # This strategy do not need to do all_reduce operation
+        # This strategy does not need to do all_reduce in either the forward or the backward phase
        communication_cost = 0
 
         sharding_strategies = ShardingStrategy(name,
@@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
         compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
 
         # compute the memory cost of this strategy
-        dtype = self.input_data.dtype
-        numel = self.output_data.numel()
-        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
-        memory_cost = numel * size_per_elem_bytes
+        sharding_size_forward = 1
+        sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
+            mesh_dim_1]
+        sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
+        memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
+                                                                         sharding_size_backward_activation,
+                                                                         sharding_size_backward_weight)
 
-        # compute communication cost
-        communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
+        # compute the communication cost during the forward phase
+        communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
+        # This strategy does NOT need to do all_reduce during the backward phase
+        communication_cost_backward = 0
+        communication_cost = communication_cost_forward + communication_cost_backward
 
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_ouput,
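Every strategy hunk above follows the same recipe: the three partition counts come from whichever mesh dimensions actually shard the output, the input gradient, and the weight gradient, and the total communication cost sums the forward and backward all-reduce terms (zero for any phase whose result is already fully reduced). A rough worked example for the first strategy, `S0S1 = S0R x RS1`, on a hypothetical 2x4 mesh with a stand-in cost model (`DeviceMesh.all_reduce_cost` is not reimplemented here, and the 1 MB gradient size is assumed):

```python
mesh_shape = (2, 4)                  # hypothetical 2 x 4 device mesh
mesh_dim_0, mesh_dim_1 = 0, 1

# output sharded over both mesh dims; input grad over mesh_dim_0; weight grad over mesh_dim_1
sharding_size_forward = mesh_shape[mesh_dim_0] * mesh_shape[mesh_dim_1]    # 8
sharding_size_backward_activation = mesh_shape[mesh_dim_0]                 # 2
sharding_size_backward_weight = mesh_shape[mesh_dim_1]                     # 4

def all_reduce_cost_sketch(num_bytes, num_devices, alpha=1e-5, beta=1e-10):
    # toy ring all-reduce latency/bandwidth model, standing in for DeviceMesh.all_reduce_cost
    return alpha + 2 * (num_devices - 1) / num_devices * num_bytes * beta

# forward needs no all-reduce; backward all-reduces the partial input gradient over mesh_dim_1
memory_cost_backward_activation = 1.0e6 / sharding_size_backward_activation  # assumed 1 MB input grad
communication_cost = 0 + all_reduce_cost_sketch(memory_cost_backward_activation,
                                                mesh_shape[mesh_dim_1])
print(communication_cost)
```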
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 5c4cc7def..62289b5ce 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -85,12 +85,17 @@ class OperatorHandler(ABC):
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
         resharding_costs = {}
-        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+        for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
                 input_sharding_spec = strategy.output_sharding_spec
                 assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
-                    input_sharding_spec, input_spec)
+                # compute the resharding cost during the forward phase
+                _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
+                    input_sharding_spec, target_spec)
+                # In the backward phase, the grad produced under target_spec must be converted back to input_sharding_spec
+                _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
+                    target_spec, input_sharding_spec)
+                resharding_cost = resharding_cost_forward + resharding_cost_backward
                 resharding_costs[input_node].append(resharding_cost)
         return resharding_costs
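The `operator_handler.py` change makes resharding accounting bidirectional: the activation is converted from the producer's spec to the target spec in the forward pass, and the gradient takes the reverse trip in the backward pass. The two directions need not cost the same (e.g. all-gathering a sharded tensor versus re-slicing a replicated one), which is why both are computed explicitly. A toy sketch of the summed round trip, assuming only that `shape_consistency` returns a 3-tuple whose last element is the cost, as in the patch; the spec strings and cost table are made up:

```python
def bidirectional_resharding_cost(shape_consistency, input_sharding_spec, target_spec):
    # forward: convert the activation from the producer's spec to the operator's target spec
    _, _, cost_forward = shape_consistency(input_sharding_spec, target_spec)
    # backward: convert the gradient produced under target_spec back to the producer's spec
    _, _, cost_backward = shape_consistency(target_spec, input_sharding_spec)
    return cost_forward + cost_backward

# stand-in manager with an asymmetric, made-up cost table
toy_costs = {('S0R', 'RR'): 3.0, ('RR', 'S0R'): 1.0}

def toy_shape_consistency(src, dst):
    return None, None, toy_costs.get((src, dst), 0.0)

print(bidirectional_resharding_cost(toy_shape_consistency, 'S0R', 'RR'))  # 4.0
```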
diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py
index 52b8ba28a..45eb87e3b 100644
--- a/tests/test_auto_parallel/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_conv_handler.py
@@ -82,7 +82,6 @@ def test_conv_handler():
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
     conv_handler.register_strategy()
-    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
     strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]