[autoparallel] adapt solver with self attention (#2037)

* [autoparallel] adapt solver with self attention

* polish code
This commit is contained in:
YuliangLiu0306
2022-12-01 17:53:15 +08:00
committed by GitHub
parent d3499c98d4
commit 1c1fe44305
6 changed files with 320 additions and 13 deletions

View File

@@ -63,14 +63,40 @@ class CostGraph:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
children_nodes = [node for node in strategies_vector.successor_nodes]
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []
def _check_tensor_in_node(data):
"""
This method is used to check whether the data has a tensor inside or not.
"""
has_tensor_flag = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
return has_tensor_flag
for node in strategies_vector.predecessor_nodes:
if _check_tensor_in_node(node._meta_data):
parent_nodes.append(node)
for node in strategies_vector.successor_nodes:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
# we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]

View File

@@ -154,12 +154,16 @@ class Solver:
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.total
# extract the memory cost in float from MemoryCost item and sum them up
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
compute_costs.append(compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
@@ -366,6 +370,8 @@ class Solver:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
if live_variable.node not in self.node_index_dict:
continue
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget