diff --git a/colossalai/auto_parallel/solver/constants.py b/colossalai/auto_parallel/solver/constants.py
index d9f06bf70..91c20d343 100644
--- a/colossalai/auto_parallel/solver/constants.py
+++ b/colossalai/auto_parallel/solver/constants.py
@@ -16,7 +16,6 @@ ELEMENTWISE_FUNC_OP = [
     torch.multiply,
     torch.nn.functional.relu,
     torch.nn.functional.dropout,
-    torch.flatten,
     # softmax should not be here
     torch.nn.functional.softmax
 ]
diff --git a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
index b57b1e83d..53ff73a90 100644
--- a/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
+++ b/colossalai/auto_parallel/solver/op_handler/reshape_handler.py
@@ -69,6 +69,7 @@ class ReshapeHandler(OperatorHandler):
         shape_consistency_manager = ShapeConsistencyManager()
         _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
                                                                                 replicate_input_sharding_spec)
+        communication_cost = communication_cost["total"]
 
         # generate resharding cost
         resharding_costs = self._generate_resharding_costs([input_sharding_spec])
diff --git a/colossalai/auto_parallel/solver/solver.py b/colossalai/auto_parallel/solver/solver.py
index 8ca756c5e..6cd1e26c8 100644
--- a/colossalai/auto_parallel/solver/solver.py
+++ b/colossalai/auto_parallel/solver/solver.py
@@ -319,6 +319,8 @@ class Solver:
         obj = 0
         for i in range(node_nums):
             assert len(s[i]) == len(c[i])
+            assert len(s[i]) == len(d[i])
+
             obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
 
         #############################################