[autoparallel] change the merge node logic (#1533)

This commit is contained in:
YuliangLiu0306
2022-09-07 11:18:19 +08:00
committed by GitHub
parent ae71036cd2
commit 44c866a3e3
3 changed files with 71 additions and 43 deletions

View File

@@ -1,5 +1,6 @@
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from typing import List
import math
from torch.fx.node import Node
@@ -23,6 +24,7 @@ class CostGraph:
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self._build_cost_graph()
@@ -50,15 +52,15 @@ class CostGraph:
setattr(dst_node, 'children', strategies_vector.successor_nodes)
if self.simplify and strategies_vector.check_merge():
for following_node in strategies_vector.successor_nodes:
self.merge_pair.append((dst_node, following_node))
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
    """Return the entry stored in ``self.edge_costs`` for one directed edge.

    Args:
        src_node: node at the producing end of the edge.
        dst_node: node at the consuming end of the edge.

    Returns:
        Whatever is stored under the key ``(src_node, dst_node)``.

    A pair that is not present in ``self.edge_costs`` is not handled here;
    the lookup error propagates to the caller.
    """
    node_pair = (src_node, dst_node)
    return self.edge_costs[node_pair]
def merge_node(self, src_node, dst_node):
'''
To merge src_node into dst_node, we need to do it in following steps:
To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge, it is important because the logical resharding costs
@@ -81,52 +83,76 @@ class CostGraph:
src_node_index = dst_node.parents.index(src_node)
# build merge_map
merge_map = {}
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
resharding_costs = strategy.resharding_costs
resharding_cost_for_src = resharding_costs[src_node]
lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
merge_map[dst_strate_index] = lowest_cost_index
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = math.inf
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
if resharding_cost < min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for dst node
self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
target_strate_index = merge_map[dst_strate_index]
self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node][
target_strate_index]
if src_node in self.extra_node_costs:
self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
target_strate_index]
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strate_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strate_index]
self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
# add new node pair to cost graph
for parent_node in src_node.parents:
new_node_pair = (parent_node, dst_node)
old_node_pair = (parent_node, src_node)
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[dst_node]):
for j in range(self.node_lens[parent_node]):
src_strate_index = merge_map[i]
edge_cost[(j, i)] = self.edge_costs[old_node_pair][(j, src_strate_index)]
self.edge_costs[new_node_pair] = edge_cost
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if args of child node contain
# both src node and dst node.
for index_pair, resharding_cost in self.edge_costs[new_node_pair]:
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect dst node and parents of src node
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for parent_node in src_node.parents:
if parent_node not in dst_node.parents:
dst_node.parents.append(parent_node)
if dst_node not in parent_node.children:
parent_node.children.append(dst_node)
# remove src node from cost graph when src node has no consumer.
if len(src_node.children) == 0:
parent_node.children.remove(src_node)
node_pair = (parent_node, src_node)
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
    """Apply every recorded merge and then flatten ``self.following_dict``.

    Does nothing when ``self.simplify`` is falsy. Otherwise:

    1. ``self.merge_pair`` is reversed in place, ``self.merge_node`` is
       called for each ``(src_node, dst_node)`` pair in that reversed
       order, and the list is reversed back to its original order.
    2. ``self.following_dict`` is replaced by a new dict with the same
       keys, where each value has been passed through
       ``self._reindexing_src``.
    """
    if not self.simplify:
        return

    # Merge in the opposite order to the one the pairs were recorded in.
    self.merge_pair.reverse()
    for src_node, dst_node in self.merge_pair:
        self.merge_node(src_node, dst_node)
    # Put the list back in its original order.
    self.merge_pair.reverse()

    # The comprehension is fully evaluated against the old mapping before
    # the attribute is rebound.
    self.following_dict = {
        dst: self._reindexing_src(src) for dst, src in self.following_dict.items()
    }