mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[autoparallel] use pytree map style to process data (#1989)
This commit is contained in:
@@ -1,13 +1,14 @@
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
import torch
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
|
||||
|
||||
class CostGraph:
|
||||
'''
|
||||
A graph data structure to simplify the edge cost graph. It has two main functions:
|
||||
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
|
||||
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
|
||||
2. To reduce the searching space, we merge computationally-trivial operators, such as
|
||||
2. To reduce the searching space, we merge computationally-trivial operators, such as
|
||||
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
|
||||
be given by the StrategiesVector depending on the type of target node and following nodes.
|
||||
|
||||
@@ -66,8 +67,6 @@ class CostGraph:
|
||||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
# self._remove_invalid_node(dst_node, 'parents')
|
||||
# self._remove_invalid_node(dst_node, 'children')
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
@@ -79,14 +78,14 @@ class CostGraph:
|
||||
def merge_node(self, src_node, dst_node):
|
||||
'''
|
||||
To merge dst_node into src_node, we need to do it in following steps:
|
||||
|
||||
|
||||
1. For each strategy in dst_node, we need to pick an appropriate strategy
|
||||
of src_node to merge, it is important because the logical resharding costs
|
||||
between the parents node of src_node and merged node depend on the src_node
|
||||
of src_node to merge, it is important because the logical resharding costs
|
||||
between the parents node of src_node and merged node depend on the src_node
|
||||
strategies dispatching. For example, for the graph 0->1->2, after merging node 1
|
||||
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
|
||||
x represents the picking strategy of node 1 merged into node 2 strategy 0.
|
||||
|
||||
|
||||
2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
|
||||
contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
|
||||
another is the origin extra costs in src_node strategy.
|
||||
@@ -98,10 +97,9 @@ class CostGraph:
|
||||
src_node(Node): The node will be merged into dst_node.
|
||||
dst_node(Node): The node to integrate src_node.
|
||||
'''
|
||||
src_node_index = dst_node.parents.index(src_node)
|
||||
# build merge_map
|
||||
merge_map = {}
|
||||
for src_index, strategy in enumerate(src_node.strategies_vector):
|
||||
for src_index, _ in enumerate(src_node.strategies_vector):
|
||||
min_cost = INFINITY_COST
|
||||
lowest_cost_index = -1
|
||||
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
|
||||
@@ -139,7 +137,6 @@ class CostGraph:
|
||||
for i in range(self.node_lens[src_node]):
|
||||
for j in range(self.node_lens[child_node]):
|
||||
dst_strate_index = merge_map[i]
|
||||
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
|
||||
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
|
||||
if new_node_pair not in self.edge_costs:
|
||||
self.edge_costs[new_node_pair] = edge_cost
|
||||
|
Reference in New Issue
Block a user