[autoparallel] handled illegal sharding strategy in shape consistency (#1744)

* [autoparallel] handled illegal sharding strategy in shape consistency

* polish code
Frank Lee
2022-10-20 12:06:25 +08:00
committed by GitHub
parent 88a79814fb
commit 993b8875b6
4 changed files with 109 additions and 89 deletions
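
For context, a sharding strategy becomes illegal for shape consistency when one of its sharding specs cannot actually be realized on the device mesh, so no conversion path to or from it exists. A minimal sketch of one such case (hypothetical sizes, not code from this commit):

    import torch

    # Shard dim 0 of a (2, 8) tensor across a mesh axis with 4 devices:
    # dim 0 has only 2 elements, so a 4-way split is impossible and any
    # sharding spec that requests it is illegal.
    tensor = torch.rand(2, 8)
    mesh_axis_size = 4             # hypothetical device-mesh axis length
    dim_partition_dict = {0: [0]}  # shard tensor dim 0 along mesh axis 0

    assert tensor.shape[0] % mesh_axis_size != 0  # cannot split evenly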

@@ -3,7 +3,7 @@ from typing import Dict, List
 import torch
 import torch.nn.functional as F
-from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
+from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
 from .node_handler import ModuleHandler, NodeHandler
 from .registry import operator_registry
 from .strategy import ConvStrategyGenerator, StrategyGenerator
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
                     dim_partition_dict[1] = second_dim_partition
                 # re-init the sharding spec
-                sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
+                sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
         return strategy
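
The change above presumably fixes a stale-shape bug: for a conv weight, the spec's entire_shape holds the logical (dimension-swapped) view, so once the partitions are switched back to the physical layout the spec must be rebuilt from op_data.data.shape, or the stale logical shape can yield an illegal spec. A toy sketch of the in-place __init__ re-init pattern (hypothetical class, not ColossalAI's ShardingSpec):

    # Calling __init__ on an existing object rebuilds its derived state in
    # place, so the shape handed in must be the tensor's real shape.
    class ToySpec:

        def __init__(self, entire_shape, dim_partition_dict):
            self.entire_shape = entire_shape
            self.dim_partition_dict = dim_partition_dict

    spec = ToySpec((16, 4, 3, 3), {0: [0]})  # logical (swapped) weight view
    physical_shape = (4, 16, 3, 3)           # what the data actually is
    spec.__init__(physical_shape, {1: [0]})  # re-init with the real shape
    assert spec.entire_shape == physical_shape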

@@ -46,6 +46,7 @@ class NodeHandler(ABC):
         # TODO: test this function when other handlers are ready
         resharding_costs = {}
+        shape_consistency_manager = ShapeConsistencyManager()
         for node in self.predecessor_node:
             node_name = str(node)
@@ -54,7 +55,9 @@
             assert hasattr(node, 'strategies_vector'), \
                 f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
             prev_strategy_vector = node.strategies_vector
-            prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
+            prev_sharding_specs = [
+                prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
+            ]
             # get the current sharding spec generated by this node handler
             op_data = strategy.get_op_data_by_name(node_name)
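
Renaming the comprehension variable to prev_strategy also stops it from shadowing the enclosing strategy, which is still needed two lines later for get_op_data_by_name. The hunk ends before the costs are actually computed, but the shape_consistency_manager created above is what prices each previous sharding spec against the current one. A hedged sketch of that step, assuming shape_consistency() returns a (transform_path, comm_actions, total_cost) tuple and treating a failed conversion as the illegal sharding strategy named in the commit title:

    import math

    def compute_resharding_costs(manager, prev_sharding_specs, current_sharding_spec):
        # Hypothetical helper mirroring the loop above: price each predecessor
        # sharding spec against the current one, and make illegal conversions
        # infinitely expensive instead of letting them crash the search.
        costs = []
        for prev_spec in prev_sharding_specs:
            try:
                _, _, total_cost = manager.shape_consistency(prev_spec, current_sharding_spec)
            except Exception:
                total_cost = math.inf  # illegal sharding strategy: never selected
            costs.append(total_cost)
        return costs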