mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[autoparallel] resnet block runtime apply (#1709)
* [autoparallel] resnet block runtime apply * seperate buffer and parameter in MemoryCost * polish code * add comments and todos * fix test issue
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
|
||||
import torch
|
||||
|
||||
|
||||
class CostGraph:
|
||||
@@ -51,7 +52,6 @@ class CostGraph:
|
||||
if src_node not in self.nodes:
|
||||
continue
|
||||
node_pair = (src_node, dst_node)
|
||||
# src_index = strategies_vector.predecessor_nodes.index(src_node)
|
||||
edge_cost = {}
|
||||
for i in range(len(strategies_vector)):
|
||||
for j in range(len(src_node.strategies_vector)):
|
||||
@@ -62,10 +62,12 @@ class CostGraph:
|
||||
edge_cost[(j, i)] = resharding_cost_item.total
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
# add parents and children attribute to node
|
||||
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
|
||||
setattr(dst_node, 'children', strategies_vector.successor_nodes)
|
||||
self._remove_invalid_node(dst_node, 'parents')
|
||||
self._remove_invalid_node(dst_node, 'children')
|
||||
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
# self._remove_invalid_node(dst_node, 'parents')
|
||||
# self._remove_invalid_node(dst_node, 'children')
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
|
Reference in New Issue
Block a user