[autoparallel] complete gpt related module search (#2097)

This commit is contained in:
YuliangLiu0306
2022-12-08 10:04:09 +08:00
committed by GitHub
parent 85efb7ac2e
commit 3af7e65dea
3 changed files with 173 additions and 53 deletions

View File

@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
last_physical_output_dims = output_op_data.data.dim() - 1
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
update_partition_dim(
sharding_spec=input_sharding_spec,
dim_mapping={last_logical_input_dims: last_physical_input_dims},
physical_shape=input_op_data.data.shape,
inplace=True,
)
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
else:
input_last_dim_mapping = {}
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
update_partition_dim(
sharding_spec=output_sharding_spec,
dim_mapping={last_logical_output_dims: last_physical_output_dims},
physical_shape=output_op_data.data.shape,
inplace=True,
)
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
else:
output_last_dim_mapping = {}
# get logger for debug message
logger = get_dist_logger()
@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with the i-th dimension in the physical sharding
input_dim_mapping = {0: i}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {0: i}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
strategy_copy.name = f'{strategy.name}_{i}'
@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
input_dim_mapping = {}
input_dim_mapping.update(input_last_dim_mapping)
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
dim_mapping=input_dim_mapping,
physical_shape=input_op_data.data.shape,
inplace=True)
output_dim_mapping = {}
output_dim_mapping.update(output_last_dim_mapping)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
dim_mapping=output_dim_mapping,
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)