[autoparallel] remove redundancy comm node (#1893)

Repository: https://github.com/hpcaitech/ColossalAI.git
@@ -74,11 +74,13 @@ class NodeHandler(ABC):
                 if op_data.type == OperationDataType.PARAM:
                     resharding_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
                 else:
+                    dtype = op_data.data.dtype
+                    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
                     _, _, resharding_cost = shape_consistency_manager.shape_consistency(
                         prev_sharding_spec, current_sharding_spec)
-                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"],
-                                                     bwd=resharding_cost["backward"],
-                                                     total=resharding_cost["total"])
+                    resharding_cost = TrainCycleItem(fwd=resharding_cost["forward"] * size_per_elem_bytes,
+                                                     bwd=resharding_cost["backward"] * size_per_elem_bytes,
+                                                     total=resharding_cost["total"] * size_per_elem_bytes)
                 resharding_costs[node].append(resharding_cost)
         strategy.resharding_costs = resharding_costs
         return strategy
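The NodeHandler hunk above changes the unit of the resharding cost from element counts to bytes, so tensors of different dtypes are no longer costed identically. A minimal standalone sketch of the same conversion, using made-up per-element costs and a hypothetical `to_bytes` helper rather than any ColossalAI API:

import torch

def to_bytes(cost_in_elements: dict, dtype: torch.dtype) -> dict:
    """Scale a {'forward', 'backward', 'total'} cost expressed in element
    counts into bytes, using the element size of the given dtype."""
    # Same trick as in the diff: element_size() is 4 for float32,
    # 2 for float16/bfloat16, and so on.
    size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
    return {k: v * size_per_elem_bytes for k, v in cost_in_elements.items()}

# Hypothetical per-element resharding cost, e.g. from a shape-consistency query.
cost = {"forward": 1024, "backward": 2048, "total": 3072}
print(to_bytes(cost, torch.float32))   # {'forward': 4096, 'backward': 8192, 'total': 12288}
print(to_bytes(cost, torch.float16))   # {'forward': 2048, 'backward': 4096, 'total': 6144}

This mirrors the inline change in the diff: derive `size_per_elem_bytes` from `op_data.data.dtype` and scale each field of the `TrainCycleItem` by it.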
@@ -218,7 +218,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

@@ -254,7 +254,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

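The BatchNormStrategyGenerator hunks differ only in the `logical_process_axis` passed to the all-reduce: a single mesh axis, both axes flattened together, or a one-element list. A rough sketch of how a logical process axis selects the ranks that take part in the collective on a small device mesh; this is illustrative only and does not use ColossalAI's DeviceMesh:

import numpy as np

def comm_groups(mesh_shape, logical_process_axes):
    """Return the rank groups an all-reduce spans when it is applied
    along the given axes of a logical device mesh."""
    ranks = np.arange(int(np.prod(mesh_shape))).reshape(mesh_shape)
    # Move the participating axes to the back and flatten them into one
    # communicator dimension; every row is one collective group.
    keep = [d for d in range(len(mesh_shape)) if d not in logical_process_axes]
    reordered = np.transpose(ranks, keep + list(logical_process_axes))
    group_size = int(np.prod([mesh_shape[d] for d in logical_process_axes]))
    return reordered.reshape(-1, group_size).tolist()

mesh = (2, 2)                      # 4 devices arranged as a 2x2 logical mesh
print(comm_groups(mesh, [0]))      # [[0, 2], [1, 3]]  -> all-reduce along mesh_dim_0
print(comm_groups(mesh, [1]))      # [[0, 1], [2, 3]]  -> all-reduce along mesh_dim_1
print(comm_groups(mesh, [0, 1]))   # [[0, 1, 2, 3]]    -> flattened all-reduce over both axes

Under this reading, `[mesh_dim_0, mesh_dim_1]` in the hunk above makes the statistics all-reduce span every device in the sketch's 2x2 mesh, while the single-axis variants only synchronize along one mesh dimension.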
@@ -300,7 +300,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0],
-            comm_type=CommType.AFTER)
+            comm_type=CommType.IMPLICIT)

         communication_action_mapping = {"output": output_comm_action}

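All three BatchNormStrategyGenerator hunks make the same substitution: the output all-reduce is tagged `CommType.IMPLICIT` instead of `CommType.AFTER`. Read together with the commit title, the intent appears to be that an AFTER action is materialized as an extra communication node behind the compute node, while an IMPLICIT action is assumed to be carried out by the op itself (sync batch norm already all-reduces its statistics), so no extra node is inserted. The toy pass below illustrates that idea; it is a simplified, hypothetical model, not ColossalAI's runtime pass:

from dataclasses import dataclass
from enum import Enum, auto

class CommType(Enum):
    BEFORE = auto()    # insert a comm node before the compute node
    AFTER = auto()     # insert a comm node after the compute node
    IMPLICIT = auto()  # the op handles the communication itself; insert nothing

@dataclass
class CommAction:
    pattern: str          # e.g. "all_reduce"
    comm_type: CommType

def materialize_comm_nodes(node_name: str, actions: list[CommAction]) -> list[str]:
    """Toy graph pass: only BEFORE/AFTER actions become standalone comm nodes."""
    nodes = []
    for act in actions:
        if act.comm_type is CommType.BEFORE:
            nodes.append(f"{act.pattern}_before_{node_name}")
    nodes.append(node_name)
    for act in actions:
        if act.comm_type is CommType.AFTER:
            nodes.append(f"{act.pattern}_after_{node_name}")
    return nodes

# With AFTER, a redundant all-reduce node would follow the already synchronized op.
print(materialize_comm_nodes("sync_bn", [CommAction("all_reduce", CommType.AFTER)]))
# ['sync_bn', 'all_reduce_after_sync_bn']
print(materialize_comm_nodes("sync_bn", [CommAction("all_reduce", CommType.IMPLICIT)]))
# ['sync_bn']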
@@ -331,14 +331,14 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # TODO: The strategies below should be uncommented after runtime
         # passes ready.
         # SR = SR x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch(0))
-        # strategy_list.append(self.split_input_batch(1))
+        strategy_list.append(self.split_input_batch(0))
+        strategy_list.append(self.split_input_batch(1))

         # SS = SS x S WITH SYNC_BN
-        # strategy_list.append(self.split_input_both_dim(0, 1))
-        # strategy_list.append(self.split_input_both_dim(1, 0))
+        strategy_list.append(self.split_input_both_dim(0, 1))
+        strategy_list.append(self.split_input_both_dim(1, 0))

         # S01R = S01R x R WITH SYNC_BN
-        # strategy_list.append(self.split_input_batch_1d(0, 1))
+        strategy_list.append(self.split_input_batch_1d(0, 1))

         return strategy_list
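The final hunk enables the sync-BN strategies, whose comments encode how the input is sharded over the 2D device mesh: `SR` splits the batch dimension over one mesh axis, `SS` reads as two tensor dimensions split over the two axes, and `S01R` splits the batch over both axes jointly. A rough sketch of these specs in a dim-partition style (each dictionary maps a tensor dimension to the mesh axes it is split over); the names, shapes, and the `shard_size` helper are illustrative, not ColossalAI definitions:

# Input to a BatchNorm2d-like op: (N, C, H, W). Keys are tensor dims, values are mesh axes.
mesh_dim_0, mesh_dim_1 = 0, 1

split_input_batch_0  = {0: [mesh_dim_0]}                   # S0 R R R  ("SR" with SYNC_BN)
split_input_batch_1  = {0: [mesh_dim_1]}                   # S1 R R R
split_input_both_dim = {0: [mesh_dim_0], 1: [mesh_dim_1]}  # S0 S1 R R ("SS")
split_input_batch_1d = {0: [mesh_dim_0, mesh_dim_1]}       # S01 R R R ("S01R")

def shard_size(full_dim: int, mesh_axes: list, mesh_shape=(2, 2)) -> int:
    """Local size of one tensor dimension after splitting it over the given mesh axes."""
    denom = 1
    for axis in mesh_axes:
        denom *= mesh_shape[axis]
    return full_dim // denom

# Batch of 16 on a 2x2 mesh:
print(shard_size(16, split_input_batch_0[0]))   # 8 -> batch split over one mesh axis
print(shard_size(16, split_input_batch_1d[0]))  # 4 -> batch split over both axes jointly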