mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 17:40:33 +00:00
[autoparallel] adapt solver with self attention (#2037)
* [autoparallel] adapt solver with self attention * polish code
This commit is contained in:
@@ -63,14 +63,40 @@ class CostGraph:
|
||||
edge_cost[(j, i)] = resharding_cost_item.total
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
# add parents and children attribute to node
|
||||
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
# children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
parent_nodes = []
|
||||
children_nodes = []
|
||||
|
||||
def _check_tensor_in_node(data):
|
||||
"""
|
||||
This method is used to check whether the data has a tensor inside or not.
|
||||
"""
|
||||
has_tensor_flag = False
|
||||
if isinstance(data, torch.Tensor):
|
||||
return True
|
||||
elif isinstance(data, (tuple, list)):
|
||||
for d in data:
|
||||
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
|
||||
return has_tensor_flag
|
||||
|
||||
for node in strategies_vector.predecessor_nodes:
|
||||
if _check_tensor_in_node(node._meta_data):
|
||||
parent_nodes.append(node)
|
||||
for node in strategies_vector.successor_nodes:
|
||||
if _check_tensor_in_node(node._meta_data):
|
||||
children_nodes.append(node)
|
||||
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
self.merge_pair.append((followed_node, dst_node))
|
||||
# we only merge node pairs which src node has a tensor element inside.
|
||||
# This is necessay because the node without a tensor element inside will not
|
||||
# be assigned any strategy.
|
||||
if _check_tensor_in_node(followed_node._meta_data):
|
||||
self.merge_pair.append((followed_node, dst_node))
|
||||
|
||||
def get_edge_cost(self, src_node, dst_node):
|
||||
return self.edge_costs[(src_node, dst_node)]
|
||||
|
Reference in New Issue
Block a user