[autoparallel] adapt autoparallel tests with latest api (#2626)

This commit is contained in:
YuliangLiu0306
2023-02-08 15:02:12 +08:00
committed by GitHub
parent c375563653
commit cb3d1bef62
8 changed files with 59 additions and 583 deletions

View File

@@ -247,12 +247,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
strategies.append(self.split_rhs_space_both_contract(1, 0))
# RR= RS x SR
# strategies.append(self.recompute_split_both_contract(0))
# strategies.append(self.recompute_split_both_contract(1))
strategies.append(self.recompute_split_both_contract(0))
strategies.append(self.recompute_split_both_contract(1))
# # RS = RR x RS
# strategies.append(self.split_rhs_space_only(0))
# strategies.append(self.split_rhs_space_only(1))
# RS = RR x RS
strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1))
# S01R = S01R x RR
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
@@ -263,8 +263,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# # RR = RR x RR
# strategies.append(self.non_split())
# RR = RR x RR
strategies.append(self.non_split())
return strategies

View File

@@ -62,9 +62,6 @@ class CostGraph:
else:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []