[autoparallel] fix linear logical convert issue (#1857)

YuliangLiu0306 2022-11-10 17:19:22 +08:00 committed by GitHub
parent c2947dadf1
commit 1b494ad73c
4 changed files with 40 additions and 10 deletions


@@ -52,7 +52,6 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         if node.op == 'get_attr':
             assert len(target_sharding_specs) == 1, f'sharing weight is not supported in current version.'
             new_sharding_spec = target_sharding_specs[0]
-            user_node = node.strategies_vector.successor_nodes[0]
             user_strategy = node.strategies_vector.successor_nodes[0].best_strategy
             op_data_in_user = user_strategy.get_op_data_by_name(str(node))
             origin_node_sharding_spec_dict[index] = new_sharding_spec


@@ -30,7 +30,8 @@ def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStr
     op_data = strategy.get_op_data_by_name(weight_name)
     assert op_data.logical_shape != op_data.data.shape, \
         "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
-    transpose_partition_dim(sharding_spec, 0, -1)
+    dim_size = len(op_data.logical_shape)
+    transpose_partition_dim(sharding_spec, 0, dim_size - 1)
     return strategy
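
Why the explicit index matters: a sharding spec's dim_partition_dict is keyed by non-negative dimension indices, so swapping dim 0 with -1 presumably parked the partition under a key that never matches the real last dimension (an inference from the fix, not stated in the commit). A minimal dict-level sketch; swap_partition_dims below is a hypothetical stand-in for transpose_partition_dim, which actually mutates a ShardingSpec:

# Hypothetical stand-in modelling only the dict bookkeeping of
# transpose_partition_dim (the real helper operates on ShardingSpec).
def swap_partition_dims(dim_partition_dict, dim1, dim2):
    d = dict(dim_partition_dict)
    d[dim1], d[dim2] = d.get(dim2, []), d.get(dim1, [])
    return {k: v for k, v in d.items() if v}    # drop unsharded dims

spec = {0: [0]}      # weight sharded on logical dim 0 along mesh axis 0
dim_size = 2         # the logical weight shape is 2D

print(swap_partition_dims(spec, 0, -1))            # {-1: [0]}: key never matches a real dim
print(swap_partition_dims(spec, 0, dim_size - 1))  # {1: [0]}: the intended last-dim partition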
@@ -54,6 +55,29 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
     input_op_data = strategy.get_op_data_by_name(input_name)
     output_op_data = strategy.get_op_data_by_name(output_name)
     input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
+    output_sharding_spec = strategy.get_sharding_spec_by_name(output_op_data.name)
+
+    # recover the last logical dimension to physical dimension
+    last_logical_input_dims = len(input_op_data.logical_shape) - 1
+    last_logical_output_dims = len(output_op_data.logical_shape) - 1
+    last_physical_input_dims = input_op_data.data.dim() - 1
+    last_physical_output_dims = output_op_data.data.dim() - 1
+
+    if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=input_sharding_spec,
+            dim_mapping={last_logical_input_dims: last_physical_input_dims},
+            physical_shape=input_op_data.data.shape,
+            inplace=True,
+        )
+    if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
+        update_partition_dim(
+            sharding_spec=output_sharding_spec,
+            dim_mapping={last_logical_output_dims: last_physical_output_dims},
+            physical_shape=output_op_data.data.shape,
+            inplace=True,
+        )
+
     # get logger for debug message
     logger = get_dist_logger()
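
What the added block does, sketched with plain dicts: F.linear folds all leading input dimensions into one, so an input of physical shape (4, 8, 16) has logical shape (32, 16), and a partition recorded on the last logical dim (index 1) has to be re-keyed to the last physical dim (index 2) before execution. update_partition_dim itself works on ShardingSpec objects; the dict below only illustrates the mapping:

import torch

x = torch.randn(4, 8, 16)                 # physical input, 3 dims
logical = x.view(-1, x.shape[-1])         # logical 2D view: (32, 16)

last_logical_dim = logical.dim() - 1      # 1
last_physical_dim = x.dim() - 1           # 2

# partition keyed on the logical last dim, as stored in the strategy
dim_partition_dict = {last_logical_dim: [0]}     # shard along mesh axis 0

# re-key it to the physical last dim, mirroring update_partition_dim's dim_mapping
physical_partition = {last_physical_dim if k == last_logical_dim else k: v
                      for k, v in dim_partition_dict.items()}
print(physical_partition)                 # {2: [0]}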
@@ -198,7 +222,14 @@ class LinearFunctionHandler(NodeHandler):
                                                type=data_type,
                                                data=self.node.args[1]._meta_data,
                                                logical_shape=self.node.args[1]._meta_data.shape[::-1])
-        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+        output_meta_data = self.node._meta_data
+        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
+        physical_output = OperationData(
+            name=str(self.node),
+            type=OperationDataType.OUTPUT,
+            data=self.node._meta_data,
+            logical_shape=output_logical_shape,
+        )

         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
@@ -219,7 +250,6 @@ class LinearFunctionHandler(NodeHandler):
         # switch the dimensions of the transposed weight
         strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                           weight_name=str(self.node.args[1]))
-
         # create multiple sharding strategies for the inputs
         # as input can be multi-dimensional and the partition dim is only 2D,
         # we need to map the partition at dim 0 to one of the first few dimensions of the input


@@ -32,7 +32,8 @@ class Solver:
                  memory_budget: float = -1.0,
                  solution_numbers: int = 1,
                  forward_only: bool = False,
-                 memory_increasing_coefficient: float = 1.3):
+                 memory_increasing_coefficient: float = 1.3,
+                 verbose=True):
         '''
         Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
         Argument:
@@ -64,6 +65,7 @@ class Solver:
         self.last_s_val = None
         # The last objective value of the best ILP solution.
         self.last_objective = None
+        self.verbose = verbose

     def _recover_merged_node_strategy(self):
         '''
@@ -177,7 +179,7 @@ class Solver:
         # omit initial value for nodes
         s_init_np = None

-        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
+        return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose

     def _call_solver_serialized_args(self,
                                      node_nums,
@@ -192,7 +194,8 @@
                                      memory_costs,
                                      resharding_costs,
                                      alias_convert_costs,
-                                     s_init_np=None):
+                                     s_init_np=None,
+                                     verbose=True):
         """
         Call the solver with serialized arguments.
         """
@@ -407,8 +410,6 @@
         # if v[idx][row * C + col] > 0.5:
         #     prob += s[i][row] + s[j][col] <= 1

-        verbose = True
-
         msg = verbose
         time_limit = 600
         assert "COIN_CMD" in pulp.listSolvers(


@@ -95,7 +95,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
     cost_graph.simplify_graph()
     graph_analyser = GraphAnalyser(gm)
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, verbose=False)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
     gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(