[autoparallel] use pytree map style to process data (#1989)

YuliangLiu0306
2022-11-21 10:44:22 +08:00
committed by GitHub
parent 35e6b9ec82
commit 155891113e
7 changed files with 178 additions and 66 deletions


@@ -1,13 +1,14 @@
import torch

from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, and it stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information will
be given by the StrategiesVector depending on the type of the target node and its following nodes.
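To make the linearization concrete, here is a minimal sketch (the cost values and names are illustrative, not the repo's API): the resharding cost between every pair of src and dst strategies is stored flat, keyed by the two strategy indices, which is the form a linear solver can consume.

    # Hypothetical resharding costs: rows are src strategies, columns are dst strategies.
    resharding_matrix = [
        [0.0, 3.5, 1.2],
        [3.5, 0.0, 4.1],
    ]

    # Linearize into the edge_cost form: one flat entry per (src_index, dst_index) pair.
    edge_cost = {(i, j): cost
                 for i, row in enumerate(resharding_matrix)
                 for j, cost in enumerate(row)}

    assert edge_cost[(1, 2)] == 4.1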
@@ -66,8 +67,6 @@ class CostGraph:
children_nodes = [node for node in strategies_vector.successor_nodes]
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
# self._remove_invalid_node(dst_node, 'parents')
# self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
@@ -79,14 +78,14 @@ class CostGraph:
def merge_node(self, src_node, dst_node):
'''
To merge src_node into dst_node, we need to take the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge. This matters because the logical resharding costs
between the parent nodes of src_node and the merged node depend on which
src_node strategy is dispatched. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
where x is the strategy of node 1 picked when the merged node 2 uses strategy 0.
2. We need to accumulate the extra costs introduced by merging nodes. These extra costs
have two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
the other is the original extra cost carried by the src_node strategy.
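A tiny worked example of the remapping described in step 1, with hypothetical costs:

    # Graph 0 -> 1 -> 2, merging node 1 into node 2.
    # edge_costs[('n0', 'n1')] is keyed by (strategy of node 0, strategy of node 1).
    edge_costs = {('n0', 'n1'): {(0, 0): 5.0, (0, 1): 2.0}}
    x = 1    # strategy of node 1 that is picked when the merged node 2 uses strategy 0
    merged_edge = {(0, 0): edge_costs[('n0', 'n1')][(0, x)]}
    assert merged_edge[(0, 0)] == 2.0    # the cost flows through the picked strategy of node 1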
@@ -98,10 +97,9 @@ class CostGraph:
src_node(Node): The node that will be merged into dst_node.
dst_node(Node): The node that integrates src_node.
'''
src_node_index = dst_node.parents.index(src_node)
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
for src_index, _ in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
@@ -139,7 +137,6 @@ class CostGraph:
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
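The loop above picks, for each src_node strategy, the dst_node strategy with the lowest cost, and the rebuilt edge_cost then routes every (src, child) entry through that pick. A self-contained sketch of the selection, with a plain cost table standing in for the real strategy objects:

    INFINITY_COST = float('inf')

    def build_merge_map(src_costs):
        # src_costs[i][j] stands in for the resharding cost of src strategy i
        # when dst_node runs its strategy j.
        merge_map = {}
        for src_index, row in enumerate(src_costs):
            min_cost, lowest_cost_index = INFINITY_COST, -1
            for dst_index, cost in enumerate(row):
                if cost < min_cost:
                    min_cost, lowest_cost_index = cost, dst_index
            merge_map[src_index] = lowest_cost_index
        return merge_map

    assert build_merge_map([[3.0, 1.0], [0.5, 2.0]]) == {0: 1, 1: 0}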


@@ -1,3 +1,4 @@
import builtins
import math
import operator
from copy import deepcopy
@@ -13,6 +14,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh
from .options import DataloaderOption, SolverOptions
@@ -49,10 +51,6 @@ class StrategiesConstructor:
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy is None:
print(strategies_vector.node.name)
print(strategies_vector)
assert False
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
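The hunk above keeps the first strategy seen for each name and queues the rest in remove_list; a compact standalone sketch of the same deduplication (not the repo's code):

    def dedup_by_name(strategies_vector):
        # Keep only the first strategy seen for each name, preserving order.
        seen, kept = set(), []
        for strategy in strategies_vector:
            if strategy.name not in seen:
                seen.add(strategy.name)
                kept.append(strategy)
        strategies_vector[:] = kept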
@@ -64,10 +62,33 @@ class StrategiesConstructor:
"""
This method builds the strategy vector for each node in the computation graph.
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
return False
def _check_no_strategy_for_data(data):
label = True
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (tuple, list)):
for d in data:
label = label and _check_no_strategy_for_data(d)
return label
return _check_no_strategy_for_data(node._meta_data)
no_strategy_node = []
for node in self.nodes:
strategies_vector = StrategiesVector(node)
print(node)
if _check_no_strategy_for_node(node):
no_strategy_node.append(node)
pass
# placeholder node
if node.op == 'placeholder':
elif node.op == 'placeholder':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
else:
@@ -80,7 +101,7 @@ class StrategiesConstructor:
placeholder_handler.register_strategy()
# get_attr node
if node.op == 'get_attr':
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
getattr_handler.register_strategy()
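Stepping back to the new _check_no_strategy_for_data helper: its recursive walk is the "pytree map style" of the commit title, treating node._meta_data as a nested container and deciding per tensor leaf. For illustration only, the same check written with torch's pytree helpers (torch.utils._pytree is a private API, so this is a sketch rather than a drop-in replacement):

    import torch
    from torch.utils._pytree import tree_flatten

    def has_no_strategy_data(meta_data) -> bool:
        # True when no torch.Tensor leaf appears anywhere in the nested meta data,
        # i.e. the node carries nothing a sharding strategy could apply to.
        leaves, _ = tree_flatten(meta_data)
        return not any(isinstance(leaf, torch.Tensor) for leaf in leaves)

    assert has_no_strategy_data([1, ('a', 2.0)])
    assert not has_no_strategy_data((torch.empty(2), [3]))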
@@ -114,10 +135,19 @@ class StrategiesConstructor:
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
if len(strategies_vector) <= 0:
print(node.name)
assert len(strategies_vector) > 0
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)