mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 06:00:07 +00:00
[autoparallel] handled illegal sharding strategy in shape consistency (#1744)
* [autoparallel] handled illegal sharding strategy in shape consistency
* polish code
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Dict, List
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
|
||||
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy
|
||||
from .node_handler import ModuleHandler, NodeHandler
|
||||
from .registry import operator_registry
|
||||
from .strategy import ConvStrategyGenerator, StrategyGenerator
|
||||
@@ -68,7 +68,7 @@ class ConvModuleHandler(ModuleHandler):
|
||||
dim_partition_dict[1] = second_dim_partition
|
||||
|
||||
# re-init the sharding spec
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
|
||||
sharding_spec.__init__(sharding_spec.device_mesh, op_data.data.shape, dim_partition_dict)
|
||||
return strategy
|
||||
|
||||
|
||||
|
@@ -46,6 +46,7 @@ class NodeHandler(ABC):
|
||||
# TODO: test this function when other handlers are ready
|
||||
resharding_costs = {}
|
||||
shape_consistency_manager = ShapeConsistencyManager()
|
||||
|
||||
for node in self.predecessor_node:
|
||||
node_name = str(node)
|
||||
|
||||
@@ -54,7 +55,9 @@ class NodeHandler(ABC):
|
||||
assert hasattr(node, 'strategies_vector'), \
|
||||
f'The predecessor node {node_name} has no strategy vector to compute the resharding cost.'
|
||||
prev_strategy_vector = node.strategies_vector
|
||||
prev_sharding_specs = [strategy.get_sharding_spec_by_name(node_name) for strategy in prev_strategy_vector]
|
||||
prev_sharding_specs = [
|
||||
prev_strategy.get_sharding_spec_by_name(node_name) for prev_strategy in prev_strategy_vector
|
||||
]
|
||||
|
||||
# get the current sharding spec generated by this node handler
|
||||
op_data = strategy.get_op_data_by_name(node_name)
|
||||
|
Reference in New Issue
Block a user