mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-16 14:41:53 +00:00
[autoparallel] adapt solver with self attention (#2037)
* [autoparallel] adapt solver with self attention * polish code
This commit is contained in:
@@ -63,14 +63,40 @@ class CostGraph:
|
||||
edge_cost[(j, i)] = resharding_cost_item.total
|
||||
self.edge_costs[node_pair] = edge_cost
|
||||
# add parents and children attribute to node
|
||||
parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
|
||||
# children_nodes = [node for node in strategies_vector.successor_nodes]
|
||||
parent_nodes = []
|
||||
children_nodes = []
|
||||
|
||||
def _check_tensor_in_node(data):
|
||||
"""
|
||||
This method is used to check whether the data has a tensor inside or not.
|
||||
"""
|
||||
has_tensor_flag = False
|
||||
if isinstance(data, torch.Tensor):
|
||||
return True
|
||||
elif isinstance(data, (tuple, list)):
|
||||
for d in data:
|
||||
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
|
||||
return has_tensor_flag
|
||||
|
||||
for node in strategies_vector.predecessor_nodes:
|
||||
if _check_tensor_in_node(node._meta_data):
|
||||
parent_nodes.append(node)
|
||||
for node in strategies_vector.successor_nodes:
|
||||
if _check_tensor_in_node(node._meta_data):
|
||||
children_nodes.append(node)
|
||||
|
||||
setattr(dst_node, 'parents', parent_nodes)
|
||||
setattr(dst_node, 'children', children_nodes)
|
||||
|
||||
if self.simplify and strategies_vector.check_merge():
|
||||
for followed_node in strategies_vector.predecessor_nodes:
|
||||
self.merge_pair.append((followed_node, dst_node))
|
||||
# we only merge node pairs which src node has a tensor element inside.
|
||||
# This is necessay because the node without a tensor element inside will not
|
||||
# be assigned any strategy.
|
||||
if _check_tensor_in_node(followed_node._meta_data):
|
||||
self.merge_pair.append((followed_node, dst_node))
|
||||
|
||||
def get_edge_cost(self, src_node, dst_node):
    """Return the recorded resharding-cost entry for the directed
    edge ``src_node -> dst_node``.

    Raises:
        KeyError: if no cost was registered for this node pair.
    """
    node_pair = (src_node, dst_node)
    return self.edge_costs[node_pair]
|
||||
|
@@ -154,12 +154,16 @@ class Solver:
|
||||
if self.forward_only:
|
||||
origin_communication_cost = communication_cost_item.fwd
|
||||
compute_cost = compute_cost_item.fwd
|
||||
# extract MemoryCost item from the memory TrainCycleItem
|
||||
memory_cost = memory_cost_item.fwd
|
||||
else:
|
||||
origin_communication_cost = communication_cost_item.total
|
||||
compute_cost = compute_cost_item.total
|
||||
# extract MemoryCost item from the memory TrainCycleItem
|
||||
memory_cost = memory_cost_item.total
|
||||
|
||||
# extract the memory cost in float from MemoryCost item and sum them up
|
||||
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
|
||||
compute_costs.append(compute_cost)
|
||||
# node in extra_node_costs means it has some extra communication
|
||||
# cost from node merging, so we need to add those extra communication
|
||||
@@ -366,6 +370,8 @@ class Solver:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
if live_variable.node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
Reference in New Issue
Block a user