diff --git a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py
index 1f1d681e0..fb2e53dad 100644
--- a/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py
@@ -40,6 +40,11 @@ class BcastOpHandler(OperatorHandler):
         for dim_index, _ in dim_partition_dict.items():
             if shape[dim_index] == 1:
                 processed_dim_partition_dict.pop(dim_index)
+        for dim_index, sharding_index_list in processed_dim_partition_dict.items():
+            sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
+            sharding_size = reduce(operator.mul, sharding_list, 1)
+            assert shape[
+                dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=shape,
                                      dim_partition_dict=processed_dim_partition_dict)
@@ -83,14 +88,10 @@ class BcastOpHandler(OperatorHandler):
                                                        entire_shape=new_entire_shape,
                                                        dim_partition_dict=new_dim_partition_dict)

-                # compute the resharding cost during forward phase
-                _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(
+                # compute the resharding cost
+                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
                     input_sharding_spec, input_spec)
-                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
-                    input_spec, input_sharding_spec)
-                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
-
                 # we need multiply the size of elem dtype to get correct communication cost
                 resharding_cost = total_resharding_cost * size_per_elem_bytes
                 resharding_costs[input_node].append(resharding_cost)
@@ -102,7 +103,11 @@ class BcastOpHandler(OperatorHandler):
         sharding_spec_list = []
         check_duplicated_list = []
         for output_dim_partition_dict in dim_partition_list:
-            output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
+            try:
+                output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                break
             sharding_seq = output_sharding_spec.sharding_sequence
             if sharding_seq not in check_duplicated_list:
                 check_duplicated_list.append(sharding_seq)
@@ -166,7 +171,7 @@ class BcastOpHandler(OperatorHandler):
     ##############################################
     #used to generate strategies for torch.matmul#
     ##############################################
-    # @exception_handler
+    @exception_handler
     def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
         # this dim partition dict only describes the batch dimensions, but in this scenario,
         # matrix dimensions are fully replicated, so it do not need extra process.
@@ -205,6 +210,7 @@ class BcastOpHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
         # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
         # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@@ -262,6 +268,7 @@ class BcastOpHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
         # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
         # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@@ -325,6 +332,7 @@ class BcastOpHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
         # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
         # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
@@ -390,6 +398,7 @@ class BcastOpHandler(OperatorHandler):
         self._split_dim_k(dim_partition_dict, mesh_dim_list)
         self._split_dim_j(dim_partition_dict, mesh_dim_list)

+    @exception_handler
     def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
         sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
@@ -426,6 +435,7 @@ class BcastOpHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
         sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
@@ -464,6 +474,7 @@ class BcastOpHandler(OperatorHandler):

         self.strategies_vector.append(sharding_strategies)

+    @exception_handler
     def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
         dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
         sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
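
The `exception_handler` decorator applied above is defined elsewhere in the solver package and is not part of this diff. Assuming it follows the same try/except-AssertionError plus warnings.warn pattern used for the output sharding specs in this patch, so that a strategy whose sharding fails the new divisibility assert is skipped instead of aborting strategy registration, a minimal sketch of the intended behavior could look like the following (hypothetical, for illustration only):

    # Hypothetical sketch, not the actual colossalai implementation.
    import functools
    import warnings

    def exception_handler(func):

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except AssertionError as e:
                # skip the offending strategy instead of failing the whole solver pass
                warnings.warn(f'{e}')

        return wrapper

For the divisibility check added in `_generate_sharding_spec`: with, for example, `device_mesh.mesh_shape == (2, 2)` and `dim_partition_dict == {0: [0, 1]}`, `sharding_size` is 2 * 2 = 4, so a tensor whose dimension 0 has size 6 would trigger the assert (6 % 4 != 0), while size 8 would pass.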