Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 13:41:43 +00:00
[autoparallel] fix conv handler numerical test (#1771)
This commit is contained in:
parent 1e88811c7a
commit 27de252334
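
The generator change follows one pattern, repeated per strategy: communication actions for the weight ("other") and bias operands were always built with CommType.HOOK, which attaches the backward all-reduce as a gradient hook and so presumably only applies to parameter operands. Each hunk below adds an else branch so that non-parameter operands (e.g. tensors passed to the functional conv) instead get a CommType.BEFORE action bound to the op's weight argument (arg_index=1) or bias keyword (key_for_kwarg='bias'). The sketch below is an illustrative reduction of that dispatch, with a stand-in enum and helper rather than ColossalAI's actual API:

    from enum import Enum, auto

    class CommType(Enum):
        HOOK = auto()      # all-reduce attached as a gradient hook on a parameter
        BEFORE = auto()    # all-reduce inserted before the op, bound to an argument

    def comm_kwargs_for(operand: str, is_parameter: bool) -> dict:
        # Parameters own their gradients, so the backward all-reduce can run in
        # a hook; plain tensors must instead be wired to the op's argument so
        # the action fires at the right point in the graph.
        if is_parameter:
            return {"comm_type": CommType.HOOK}
        if operand == "other":    # the weight is positional argument 1 of conv
            return {"comm_type": CommType.BEFORE, "arg_index": 1}
        return {"comm_type": CommType.BEFORE, "key_for_kwarg": "bias"}

    print(comm_kwargs_for("other", is_parameter=False))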
@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
         communication_action_mapping["other"] = other_comm_action
 
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
         communication_action_mapping["other"] = other_comm_action
 
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
 
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping["other"] = other_comm_action
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
         input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0,
+            logical_process_axis=mesh_dim_1,
             comm_type=CommType.BEFORE,
             arg_index=0)
 
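This one-line hunk is the core numerical fix: for this strategy the input's backward all-reduce was issued over mesh_dim_0, while the matching shard layout presumably calls for mesh_dim_1, so partial gradients were reduced across the wrong process groups. A toy illustration of why the axis matters (an assumed 2x2 mesh, not ColossalAI's DeviceMesh API):

    # Ranks laid out as [[0, 1], [2, 3]]: reducing over mesh_dim_0 groups the
    # columns {0, 2} and {1, 3}; reducing over mesh_dim_1 groups the rows
    # {0, 1} and {2, 3}. The wrong axis sums partials from unrelated shards,
    # exactly the kind of error a numerical test catches.
    mesh = [[0, 1], [2, 3]]
    groups_dim0 = [list(col) for col in zip(*mesh)]  # [[0, 2], [1, 3]]
    groups_dim1 = [list(row) for row in mesh]        # [[0, 1], [2, 3]]
    print(groups_dim0, groups_dim1)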
@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
         communication_action_mapping["other"] = other_comm_action
 
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1],
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
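In this last generator hunk, logical_process_axis is a list, [mesh_dim_0, mesh_dim_1], meaning the all-reduce spans both mesh axes at once, i.e. every device in the flattened mesh. Continuing the toy mesh above (still an assumption, not the library's API):

    # Reducing over both axes of the 2x2 mesh is one group with every rank,
    # equivalent to flattening the mesh:
    mesh = [[0, 1], [2, 3]]
    flattened_group = [rank for row in mesh for rank in row]  # [0, 1, 2, 3]
    print(flattened_group)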
@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False):
     mp.spawn(run_func, nprocs=world_size)
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
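
With the @pytest.mark.skip markers gone, the two tests are gated only by the AUTO_PARALLEL environment flag and the dist marker. A minimal sketch of how such a flag gate can be implemented (the real run_on_environment_flag lives in ColossalAI's test utilities; this equivalent is an assumption):

    import os
    import pytest

    def run_on_environment_flag(name: str):
        # Skip the test unless the named environment variable is set to '1';
        # assumed minimal stand-in for ColossalAI's helper of the same name.
        enabled = os.environ.get(name, '0') == '1'
        return pytest.mark.skipif(not enabled, reason=f"{name} flag not set")

Under that gate, the re-enabled cases would run with something like AUTO_PARALLEL=1 pytest -m dist on a multi-GPU host.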