Mirror of https://github.com/hpcaitech/ColossalAI.git
[autoparallel] update getitem handler (#2207)
@@ -223,7 +223,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
             node.args = new_args

         elif isinstance(getitem_index, (tuple, list)):
-            assert isinstance(getitem_index[0], slice)
+            if not isinstance(getitem_index[0], slice):
+                continue
             new_slice_items = []

             for slice_item in getitem_index:
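The hunk above swaps a hard assert for a skip: a tuple or list passed to `__getitem__` is not guaranteed to contain slices (it can mix integers, e.g. `x[0, 1]`), so asserting on the first element could abort the whole size-value conversion pass. A minimal standalone sketch of the index kinds involved, with illustrative names that are not part of the repository:

def classify_getitem_index(getitem_index):
    # Mirrors the branch structure used by the pass: int, slice, or tuple/list.
    if isinstance(getitem_index, int):
        return "int"
    if isinstance(getitem_index, slice):
        return "slice"
    if isinstance(getitem_index, (tuple, list)):
        # A tuple index may mix ints and slices; only the all-slice case is
        # rewritten, other tuples are now skipped instead of tripping an assert.
        return "tuple of slices" if isinstance(getitem_index[0], slice) else "other tuple"
    return "other"


print(classify_getitem_index(2))                            # 'int'             -> x[2]
print(classify_getitem_index(slice(0, 2)))                  # 'slice'           -> x[0:2]
print(classify_getitem_index((slice(0, 2), slice(1, 3))))   # 'tuple of slices' -> x[0:2, 1:3]
print(classify_getitem_index((0, 1)))                       # 'other tuple'     -> x[0, 1]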
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']


 @operator_registry.register(BCAST_FUNC_OP)
-class BinaryElementwiseHandler(MetaInfoNodeHandler):
+class BinaryElementwiseHandler(NodeHandler):
     """
     An BinaryBcastOpHandler is a node handler which deals with operations which have two
     operands and broadcasting occurs such as torch.add.
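For context, `@operator_registry.register(BCAST_FUNC_OP)` maps every broadcast-capable callable to this handler class. A rough sketch of that registry pattern, with stand-in names (`Registry`, `get`, the operator list) rather than ColossalAI's actual API:

import operator
from typing import Callable, Dict


class Registry:
    def __init__(self) -> None:
        self._store: Dict[Callable, type] = {}

    def register(self, targets):
        # Accept one callable or a list of them, the way BCAST_FUNC_OP bundles
        # broadcast-capable ops (add, sub, mul, ...).
        targets = targets if isinstance(targets, (list, tuple)) else [targets]

        def wrapper(handler_cls: type) -> type:
            for target in targets:
                self._store[target] = handler_cls
            return handler_cls

        return wrapper

    def get(self, target: Callable) -> type:
        return self._store[target]


operator_registry = Registry()
BCAST_FUNC_OP = [operator.add, operator.sub, operator.mul]


@operator_registry.register(BCAST_FUNC_OP)
class BinaryElementwiseHandler:    # stand-in for the real NodeHandler subclass
    pass


print(operator_registry.get(operator.add).__name__)    # BinaryElementwiseHandler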
@@ -7,7 +7,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
+from colossalai.logging import get_dist_logger
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from colossalai.tensor.sharding_spec import ShardingSpecException

 from .strategy_generator import FollowingStrategyGenerator
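The two added imports support the rewritten strategy collation below: strategies whose sharding specs cannot be constructed raise `ShardingSpecException`, get logged at debug level, and are skipped. A self-contained sketch of that log-and-skip pattern using stand-in classes instead of the real ColossalAI objects:

class ShardingSpecException(Exception):    # stand-in for colossalai.tensor.sharding_spec.ShardingSpecException
    pass


class _Logger:                             # stand-in for the logger returned by get_dist_logger()
    def debug(self, msg):
        print(f"[debug] {msg}")


def build_strategy(index):                 # hypothetical helper that may fail
    if index % 2:
        raise ShardingSpecException(f"invalid sharding spec for strategy {index}")
    return f"strategy_{index}"


logger = _Logger()
strategy_list = []
for index in range(4):
    try:
        strategy = build_strategy(index)
    except ShardingSpecException as e:
        logger.debug(e)
        continue
    strategy_list.append(strategy)

print(strategy_list)                       # ['strategy_0', 'strategy_2']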
@@ -69,39 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):

     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         getitem_index = self.op_data['index'].data
         for index, strategy in enumerate(self.predecessor_node.strategies_vector):
-            dim_partition_dict_mapping = {}
-            communication_action_mapping = {}
-            dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict
-            dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
-            gather_input = 0 in dim_partition_dict_for_input
-            if gather_input:
-                logical_process_axis = dim_partition_dict_for_output.pop(0)
+            try:
+                logger = get_dist_logger()
+                dim_partition_dict_mapping = {}
+                communication_action_mapping = {}
+                dim_partition_dict_for_input = copy.deepcopy(
+                    strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict)
+
-                shift_dim_partition_dict_for_output = {}
-                for dim, mesh_dim_list in dim_partition_dict_for_output.items():
-                    shift_dim_partition_dict_for_output[dim - 1] = mesh_dim_list
-                dim_partition_dict_for_output = shift_dim_partition_dict_for_output
-            dim_partition_dict_mapping = {
-                "input": dim_partition_dict_for_input,
-                "output": dim_partition_dict_for_output,
-            }
-            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
-            if gather_input:
-                input_communication_action = self.get_communication_action(
-                    sharding_spec_mapping["input"],
-                    communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                    logical_process_axis=logical_process_axis,
-                    comm_type=CommType.BEFORE,
-                    arg_index=0)
-                communication_action_mapping["input"] = input_communication_action
+                int_index = False
+                if isinstance(getitem_index, int):
+                    int_index = True
+                    getitem_dims = [
+                        0,
+                    ]
+                    shift_length = 1
+                elif isinstance(getitem_index, slice):
+                    getitem_dims = [
+                        0,
+                    ]
+                else:
+                    getitem_dims = [i for i in range(len(getitem_index))]
+                    if isinstance(getitem_index[0], int):
+                        int_index = True
+                        shift_length = len(getitem_index)
+
-            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
+                gather_dims = []
+                for dim in getitem_dims:
+                    if dim in dim_partition_dict_for_input:
+                        gather_dims.append(dim)
+
-            strategy = self.get_sharding_strategy(name=name,
-                                                  sharding_spec_mapping=sharding_spec_mapping,
-                                                  communication_action_mapping=communication_action_mapping)
+                for dim in gather_dims:
+                    dim_partition_dict_for_input.pop(dim)
+                dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
+
+                if int_index:
+                    shift_dim_partition_dict_for_output = {}
+                    for dim, mesh_dim_list in dim_partition_dict_for_output.items():
+                        shift_dim_partition_dict_for_output[dim - shift_length] = mesh_dim_list
+                    dim_partition_dict_for_output = shift_dim_partition_dict_for_output
+
+                dim_partition_dict_mapping = {
+                    "input": dim_partition_dict_for_input,
+                    "output": dim_partition_dict_for_output,
+                }
+                sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+                name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
+
+                strategy = self.get_sharding_strategy(name=name,
+                                                      sharding_spec_mapping=sharding_spec_mapping,
+                                                      communication_action_mapping=communication_action_mapping)
+            except ShardingSpecException as e:
+                logger.debug(e)
+                continue
             strategy_list.append(strategy)

         for strategy in strategy_list: