mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[autoparallel] complete gpt related module search (#2097)
This commit is contained in:
@@ -64,20 +64,14 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
||||
last_physical_output_dims = output_op_data.data.dim() - 1
|
||||
|
||||
if last_logical_input_dims in input_sharding_spec.dim_partition_dict:
|
||||
update_partition_dim(
|
||||
sharding_spec=input_sharding_spec,
|
||||
dim_mapping={last_logical_input_dims: last_physical_input_dims},
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True,
|
||||
)
|
||||
input_last_dim_mapping = {last_logical_input_dims: last_physical_input_dims}
|
||||
else:
|
||||
input_last_dim_mapping = {}
|
||||
|
||||
if last_logical_output_dims in output_sharding_spec.dim_partition_dict:
|
||||
update_partition_dim(
|
||||
sharding_spec=output_sharding_spec,
|
||||
dim_mapping={last_logical_output_dims: last_physical_output_dims},
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True,
|
||||
)
|
||||
output_last_dim_mapping = {last_logical_output_dims: last_physical_output_dims}
|
||||
else:
|
||||
output_last_dim_mapping = {}
|
||||
|
||||
# get logger for debug message
|
||||
logger = get_dist_logger()
|
||||
@@ -97,12 +91,18 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
try:
|
||||
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
|
||||
input_dim_mapping = {0: i}
|
||||
input_dim_mapping.update(input_last_dim_mapping)
|
||||
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
dim_mapping=input_dim_mapping,
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
output_dim_mapping = {0: i}
|
||||
output_dim_mapping.update(output_last_dim_mapping)
|
||||
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={0: i},
|
||||
dim_mapping=output_dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
strategy_copy.name = f'{strategy.name}_{i}'
|
||||
@@ -120,12 +120,17 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: Sha
|
||||
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
|
||||
|
||||
# after updating, the logical shape will be replaced by the physical shape
|
||||
input_dim_mapping = {}
|
||||
input_dim_mapping.update(input_last_dim_mapping)
|
||||
update_partition_dim(sharding_spec=input_sharding_spec,
|
||||
dim_mapping={},
|
||||
dim_mapping=input_dim_mapping,
|
||||
physical_shape=input_op_data.data.shape,
|
||||
inplace=True)
|
||||
|
||||
output_dim_mapping = {}
|
||||
output_dim_mapping.update(output_last_dim_mapping)
|
||||
update_partition_dim(sharding_spec=output_sharding_spec,
|
||||
dim_mapping={},
|
||||
dim_mapping=output_dim_mapping,
|
||||
physical_shape=output_op_data.data.shape,
|
||||
inplace=True)
|
||||
sharding_strategies.append(strategy_copy)
|
||||
|
Reference in New Issue
Block a user