[autoparallel] adapt runtime passes (#1703)

* [autoparallel] adapt runtime passes v2

* polish code
YuliangLiu0306 authored 2022-10-14 10:14:07 +08:00, committed by GitHub
parent 21962e1593
commit 451cd72dea
5 changed files with 204 additions and 6 deletions


@@ -58,9 +58,6 @@ class CostGraph:
         edge_cost = {}
         for i in range(len(strategies_vector)):
             for j in range(len(src_node.strategies_vector)):
-                if strategies_vector[i].resharding_costs is None:
-                    print(strategies_vector.node.name)
-                    assert False
                 resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
                 if self.forward_only:
                     edge_cost[(j, i)] = resharding_cost_item.fwd
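
For reference, a minimal, self-contained sketch of the edge-cost construction this hunk touches, with plain stand-in types instead of the real CostGraph / StrategiesVector classes; the non-forward-only branch, which the hunk does not show, is assumed here to use the total cost.

# Sketch only: simplified stand-ins, not the actual ColossalAI classes.
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass
class CostItem:
    # per-pass cost of resharding between two sharding layouts
    fwd: float
    bwd: float
    total: float

def build_edge_cost(dst_resharding_costs: List[Dict[str, List[CostItem]]],
                    src_node: str,
                    num_src_strategies: int,
                    forward_only: bool) -> Dict[Tuple[int, int], float]:
    # edge_cost[(j, i)] = cost of resharding the output of src strategy j
    # into the layout expected by dst strategy i
    edge_cost = {}
    for i, resharding_costs in enumerate(dst_resharding_costs):
        for j in range(num_src_strategies):
            item = resharding_costs[src_node][j]
            edge_cost[(j, i)] = item.fwd if forward_only else item.total
    return edge_cost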


@@ -90,8 +90,8 @@ class NodeHandler(ABC):
         # compute the resharding costs based on the previous node
         # strategies if specified
         if compute_resharding_cost:
-            updated_strategies = map(self.update_resharding_cost, strategies)
-            strategies = list(updated_strategies)
+            updated_strategies = map(self.update_resharding_cost, post_processed_strategies)
+            post_processed_strategies = list(updated_strategies)
         self.strategies_vector.extend(post_processed_strategies)
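
A rough sketch of the corrected flow, with hypothetical post_process and update_resharding_cost callables standing in for the real NodeHandler methods: resharding costs are now attached to the post-processed strategies rather than the raw ones, so they stay consistent with what actually lands in the strategies vector.

# Sketch only: hypothetical helpers standing in for NodeHandler internals.
from typing import Callable, List

def collect_strategies(raw_strategies: List[object],
                       post_process: Callable[[object], object],
                       update_resharding_cost: Callable[[object], object],
                       strategies_vector: List[object],
                       compute_resharding_cost: bool = True) -> None:
    # post-processing may rewrite or drop strategies, so resharding costs
    # must be computed on its output, not on the raw strategies
    post_processed_strategies = [post_process(s) for s in raw_strategies]
    if compute_resharding_cost:
        post_processed_strategies = list(map(update_resharding_cost, post_processed_strategies))
    strategies_vector.extend(post_processed_strategies)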


@@ -52,7 +52,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
         total_compute_cost = forward_compute_cost + backward_compute_cost
         compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
-        return compute_cost
+        strategy.compute_cost = compute_cost
 
     def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         forward_size_mapping = {
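
A small sketch of the new contract in this last hunk, with TrainCycleItem and ShardingStrategy reduced to simplified stand-in dataclasses: update_compute_cost now attaches the cost to the strategy in place instead of returning it.

# Sketch only: simplified stand-ins for TrainCycleItem / ShardingStrategy.
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainCycleItem:
    fwd: float
    bwd: float
    total: float

@dataclass
class ShardingStrategy:
    compute_cost: Optional[TrainCycleItem] = None

def update_compute_cost(strategy: ShardingStrategy,
                        forward_compute_cost: float,
                        backward_compute_cost: float) -> None:
    total_compute_cost = forward_compute_cost + backward_compute_cost
    # store the cost on the strategy rather than returning it
    strategy.compute_cost = TrainCycleItem(fwd=forward_compute_cost,
                                           bwd=backward_compute_cost,
                                           total=total_compute_cost)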