[autoparallel] resnet block runtime apply (#1709)

* [autoparallel] resnet block runtime apply

* separate buffer and parameter in MemoryCost

* polish code

* add comments and todos

* fix test issue
This commit is contained in:
YuliangLiu0306
2022-10-17 13:37:38 +08:00
committed by GitHub
parent b0a23dc4fc
commit 845ff4a47a
11 changed files with 277 additions and 27 deletions

View File

@@ -1,4 +1,5 @@
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
import torch
class CostGraph:
@@ -51,7 +52,6 @@ class CostGraph:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
@@ -62,10 +62,12 @@ class CostGraph:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
children_nodes = [node for node in strategies_vector.successor_nodes]
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
# self._remove_invalid_node(dst_node, 'parents')
# self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes: