[NFC] polish colossalai/auto_parallel/tensor_shard/deprecated/cost_graph.py code style (#2720)

Co-authored-by: Fuzhao Xue <fuzhao@login2.ls6.tacc.utexas.edu>
Xue Fuzhao 2023-02-15 16:12:45 +08:00 committed by GitHub
parent 51c45c2460
commit e81caeb4bc


@@ -1,6 +1,8 @@
-from typing import List
 import math
+from typing import List
 from torch.fx.node import Node
 from .constants import INFINITY_COST
@@ -9,7 +11,7 @@ class CostGraph:
 A graph data structure to simplify the edge cost graph. It has two main functions:
 1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
 CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
 2. To reduce the searching space, we merge computationally-trivial operators, such as
 element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
 be given by the StrategiesVector depending on the type of target node and following nodes.
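
Reviewer note: the first point of the docstring above (storing every strategy combination for a src-dst pair in a 1D list) can be illustrated with a minimal sketch. The function name `linearize_edge_costs`, the row-major layout, and the sample costs are all assumptions made for illustration; they are not taken from `cost_graph.py` and may differ from how `CostGraph` actually builds `edge_cost`.

```python
from typing import Dict, List, Tuple


def linearize_edge_costs(
    cost_matrix: List[List[float]],
) -> Tuple[List[float], Dict[Tuple[int, int], int]]:
    """Flatten a (num_src x num_dst) resharding-cost matrix into a 1D list.

    Hypothetical helper: returns the flat list plus a lookup from
    (src_strategy, dst_strategy) to the position in that list.
    """
    flat: List[float] = []
    index_of: Dict[Tuple[int, int], int] = {}
    for i, row in enumerate(cost_matrix):
        for j, cost in enumerate(row):
            index_of[(i, j)] = len(flat)  # position of this pair in the 1D list
            flat.append(cost)
    return flat, index_of


# Made-up costs: 2 strategies for the src node, 3 for the dst node.
costs = [
    [0.0, 2.0, 5.0],
    [3.0, 0.0, 1.0],
]
flat, index_of = linearize_edge_costs(costs)
```

With one entry per strategy pair, a solver can attach one variable to each entry instead of multiplying two per-node choice variables, which is the usual reason for flattening a pairwise cost this way.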
@@ -75,14 +77,14 @@ class CostGraph:
 def merge_node(self, src_node, dst_node):
 '''
 To merge dst_node into src_node, we need to do it in following steps:
 1. For each strategy in dst_node, we need to pick an appropriate strategy
 of src_node to merge, it is important because the logical resharding costs
 between the parents node of src_node and merged node depend on the src_node
 strategies dispatching. For example, for the graph 0->1->2, after merging node 1
 into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
 x represents the picking strategy of node 1 merged into node 2 strategy 0.
 2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
 contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
 another is the origin extra costs in src_node strategy.
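
Reviewer note: the two steps in the `merge_node` docstring can be sketched on the 0->1->2 example it mentions. This is a simplified illustration, not the code in this file: `merge_middle_node` and its parameters are hypothetical, and "an appropriate strategy" is assumed here to mean the node 1 strategy with the lowest resharding cost to the given node 2 strategy.

```python
from typing import Dict, List, Tuple

EdgeCost = Dict[Tuple[int, int], float]  # (src_strategy, dst_strategy) -> cost


def merge_middle_node(
    cost_01: EdgeCost,
    cost_12: EdgeCost,
    extra_1: List[float],
    num_s0: int,
    num_s1: int,
    num_s2: int,
) -> Tuple[EdgeCost, List[float], Dict[int, int]]:
    """Merge node 1 into node 2 for the chain 0 -> 1 -> 2 (hypothetical helper)."""
    merge_map: Dict[int, int] = {}
    extra_2: List[float] = []
    for j in range(num_s2):
        # Step 1 (assumption): pick the node 1 strategy x that is cheapest
        # to reshard into node 2 strategy j.
        x = min(range(num_s1), key=lambda k: cost_12[(k, j)])
        merge_map[j] = x
        # Step 2: extra cost = resharding cost between the picked pair
        # plus the extra cost node 1 already carried for strategy x.
        extra_2.append(cost_12[(x, j)] + extra_1[x])
    # The new edge 0 -> 2 reuses the 0 -> 1 cost of the picked strategy:
    # cost_02[(i, j)] = cost_01[(i, x)] with x = merge_map[j].
    cost_02: EdgeCost = {
        (i, j): cost_01[(i, merge_map[j])]
        for i in range(num_s0)
        for j in range(num_s2)
    }
    return cost_02, extra_2, merge_map


# Made-up costs with two strategies per node.
cost_01 = {(0, 0): 1.0, (0, 1): 4.0, (1, 0): 2.0, (1, 1): 0.5}
cost_12 = {(0, 0): 3.0, (0, 1): 0.0, (1, 0): 1.0, (1, 1): 2.0}
extra_1 = [0.5, 0.25]
cost_02, extra_2, merge_map = merge_middle_node(cost_01, cost_12, extra_1, 2, 2, 2)
```

In this toy input, node 2 strategy 0 picks node 1 strategy 1, so `cost_02[(0, 0)]` equals `cost_01[(0, 1)]`, matching the `edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]` relation in the docstring.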