mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 17:17:05 +00:00
[autoparallel] adapt runtime passes (#1703)
* [autoparallel] adapt runtime passes v2
* polish code
This commit is contained in:
@@ -58,9 +58,6 @@ class CostGraph:
|
||||
edge_cost = {}
|
||||
for i in range(len(strategies_vector)):
|
||||
for j in range(len(src_node.strategies_vector)):
|
||||
if strategies_vector[i].resharding_costs is None:
|
||||
print(strategies_vector.node.name)
|
||||
assert False
|
||||
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
|
||||
if self.forward_only:
|
||||
edge_cost[(j, i)] = resharding_cost_item.fwd
|
||||
|
@@ -90,8 +90,8 @@ class NodeHandler(ABC):
|
||||
# compute the resharding costs based on the previous node
|
||||
# strategies if specified
|
||||
if compute_resharding_cost:
|
||||
updated_strategies = map(self.update_resharding_cost, strategies)
|
||||
strategies = list(updated_strategies)
|
||||
updated_strategies = map(self.update_resharding_cost, post_processed_strategies)
|
||||
post_processed_strategies = list(updated_strategies)
|
||||
|
||||
self.strategies_vector.extend(post_processed_strategies)
|
||||
|
||||
|
@@ -52,7 +52,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
|
||||
total_compute_cost = forward_compute_cost + backward_compute_cost
|
||||
|
||||
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
|
||||
return compute_cost
|
||||
strategy.compute_cost = compute_cost
|
||||
|
||||
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
||||
forward_size_mapping = {
|
||||
|
Reference in New Issue
Block a user