Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
[autoparallel] fix conv handler numerical test (#1771)
This commit is contained in:
@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias and self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
else:
|
||||
other_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["other"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=1)
|
||||
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias:
|
||||
if self.is_param('bias'):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
else:
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.BEFORE,
|
||||
key_for_kwarg='bias')
|
||||
communication_action_mapping["bias"] = bias_comm_action
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias and self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
else:
|
||||
other_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["other"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=1)
|
||||
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias:
|
||||
if self.is_param('bias'):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
else:
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.BEFORE,
|
||||
key_for_kwarg='bias')
|
||||
communication_action_mapping["bias"] = bias_comm_action
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias and self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
else:
|
||||
other_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["other"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=1)
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
if self.has_bias:
|
||||
if self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.HOOK)
|
||||
else:
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=mesh_dim_0,
|
||||
comm_type=CommType.BEFORE,
|
||||
key_for_kwarg='bias')
|
||||
communication_action_mapping["bias"] = bias_comm_action
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
-logical_process_axis=mesh_dim_0,
|
||||
+logical_process_axis=mesh_dim_1,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
|
||||
@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.HOOK)
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias and self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
else:
|
||||
other_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["other"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.HOOK)
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=1)
|
||||
|
||||
communication_action_mapping["other"] = other_comm_action
|
||||
|
||||
if self.has_bias:
|
||||
if self.is_param("bias"):
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.HOOK)
|
||||
else:
|
||||
bias_comm_action = self.get_communication_action(
|
||||
sharding_spec_mapping["bias"],
|
||||
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
|
||||
logical_process_axis=[mesh_dim_0, mesh_dim_1],
|
||||
comm_type=CommType.BEFORE,
|
||||
key_for_kwarg='bias')
|
||||
communication_action_mapping["bias"] = bias_comm_action
|
||||
|
||||
return self.get_sharding_strategy(name=name,
|
||||
|
Reference in New Issue
Block a user