mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 01:28:31 +00:00
[autoparallel] adapt solver with self attention (#2037)
* [autoparallel] adapt solver with self attention * polish code
This commit is contained in:
@@ -154,12 +154,16 @@ class Solver:
|
||||
if self.forward_only:
|
||||
origin_communication_cost = communication_cost_item.fwd
|
||||
compute_cost = compute_cost_item.fwd
|
||||
# extract MemoryCost item from the memory TrainCycleItem
|
||||
memory_cost = memory_cost_item.fwd
|
||||
else:
|
||||
origin_communication_cost = communication_cost_item.total
|
||||
compute_cost = compute_cost_item.total
|
||||
# extract MemoryCost item from the memory TrainCycleItem
|
||||
memory_cost = memory_cost_item.total
|
||||
|
||||
# extract the memory cost in float from MemoryCost item and sum them up
|
||||
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
|
||||
compute_costs.append(compute_cost)
|
||||
# node in extra_node_costs means it has some extra communication
|
||||
# cost from node merging, so we need to add those extra communication
|
||||
@@ -366,6 +370,8 @@ class Solver:
|
||||
for liveness_stage in liveness_set:
|
||||
mem = 0
|
||||
for live_variable in liveness_stage.unique_live_vars:
|
||||
if live_variable.node not in self.node_index_dict:
|
||||
continue
|
||||
node_index = self.node_index_dict[live_variable.node]
|
||||
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
|
||||
prob += mem <= memory_budget
|
||||
|
Reference in New Issue
Block a user