[autoparallel] add resnet autoparallel unit test and add backward weight communication cost (#1589)
@@ -103,7 +103,7 @@ class ConvHandler(OperatorHandler):
         # memory_cost pair
         memory_cost = (memory_cost_forward, memory_cost_backward)
 
-        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
+        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
 
     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
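The widened return value above is the pivot of the whole commit: _generate_memory_cost now reports the weight-grad size as a fourth element, so each strategy can price its weight synchronization separately. A minimal sketch of the assumed contract, with a simple bytes-per-shard model and hypothetical element counts (not the actual ColossalAI implementation):

from typing import Tuple

def generate_memory_cost_sketch(numel_output: int, numel_input: int, numel_weight: int,
                                sharding_size_forward: int,
                                sharding_size_backward_activation: int,
                                sharding_size_weight: int,
                                bytes_per_element: int = 4) -> Tuple[tuple, float, float, float]:
    # Each phase stores its tensor divided across the devices that shard it.
    memory_cost_forward_activation = numel_output * bytes_per_element / sharding_size_forward
    memory_cost_backward_activation = numel_input * bytes_per_element / sharding_size_backward_activation
    memory_cost_backward_weight = numel_weight * bytes_per_element / sharding_size_weight
    memory_cost = (memory_cost_forward_activation,
                   memory_cost_backward_activation + memory_cost_backward_weight)
    # The fourth element is new: callers use it to price the weight-grad all-reduce.
    return (memory_cost, memory_cost_forward_activation,
            memory_cost_backward_activation, memory_cost_backward_weight)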
@@ -132,15 +132,18 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # This strategy do not need to do all_reduce operation during forward
         communication_cost_forward = 0
-        # compute the backward communication cost of this strategy
-        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # compute the backward communication cost to all reduce the input activation grad
+        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
+                                                                                  mesh_dim_1)
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
         # total communication cost
-        communication_cost = communication_cost_forward + communication_cost_backward
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
 
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
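Each new communication term is an all-reduce priced by gradient size and mesh dimension. The exact formula inside DeviceMesh.all_reduce_cost is not shown in this diff, so the sketch below substitutes a standard alpha-beta ring all-reduce model, with a hypothetical mesh shape and gradient sizes, just to make the composition in the hunk concrete:

def ring_all_reduce_cost(size_bytes: float, num_devices: int,
                         alpha: float = 1e-5, beta: float = 1e-9) -> float:
    # Ring all-reduce: 2*(n-1) steps, each moving size/n bytes per device.
    if num_devices <= 1:
        return 0.0
    steps = 2 * (num_devices - 1)
    return steps * alpha + steps * (size_bytes / num_devices) * beta

mesh_shape = (2, 4)                                   # hypothetical 2 x 4 device mesh
memory_cost_backward_activation = 64 * 1024 * 1024    # hypothetical grad sizes, bytes
memory_cost_backward_weight = 4 * 1024 * 1024

communication_cost_forward = 0                        # no forward collective here
communication_cost_backward_activation = ring_all_reduce_cost(
    memory_cost_backward_activation, mesh_shape[1])   # input grad over mesh_dim_1
communication_cost_backward_weight = ring_all_reduce_cost(
    memory_cost_backward_weight, mesh_shape[0])       # weight grad over mesh_dim_0
communication_cost = (communication_cost_forward
                      + communication_cost_backward_activation
                      + communication_cost_backward_weight)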
@@ -178,11 +181,16 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)
 
-        # This strategy do not need to do all_reduce operation in both forward and backward phase.
-        communication_cost = 0
+        # This strategy do not need to do all_reduce operation in forward phase.
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute the total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_weight
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
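This hunk corrects an undercount rather than just renaming variables: a batch-only sharding replicates the weight, so after backward every device holds a partial weight grad that must be all-reduced across the data-parallel dimension, and the strategy can no longer be priced as communication-free. Reusing ring_all_reduce_cost from the sketch above:

def data_parallel_weight_sync_cost(memory_cost_backward_weight: float,
                                   dp_size: int) -> float:
    # Forward needs no collective; backward must synchronize the weight grad.
    communication_cost_forward = 0.0
    communication_cost_backward_weight = ring_all_reduce_cost(
        memory_cost_backward_weight, dp_size)
    return communication_cost_forward + communication_cost_backward_weight

assert data_parallel_weight_sync_cost(4 * 1024 * 1024, 2) > 0  # was reported as 0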
@@ -220,15 +228,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
-        # This strategy do not need to do all_reduce operation during backward phase
-        communication_cost_backward = 0
-        communication_cost = communication_cost_forward + communication_cost_backward
+        # This strategy do not need to do all_reduce operation to compute the input activation grad
+        communication_cost_backward_activation = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
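After this change the strategy carries collectives in both phases: a forward all-reduce of the partial output sums over mesh_dim_1, and a backward all-reduce of the weight grad over mesh_dim_0, with the input-activation grad needing none. Composing the three terms under the same toy model (mesh_shape and ring_all_reduce_cost as above; sizes hypothetical):

memory_cost_forward_activation = 32 * 1024 * 1024    # hypothetical, bytes
communication_cost_forward = ring_all_reduce_cost(
    memory_cost_forward_activation, mesh_shape[1])    # partial sums over mesh_dim_1
communication_cost_backward_activation = 0            # input grad needs no collective
communication_cost_backward_weight = ring_all_reduce_cost(
    memory_cost_backward_weight, mesh_shape[0])       # data-parallel weight sync
communication_cost = (communication_cost_forward
                      + communication_cost_backward_activation
                      + communication_cost_backward_weight)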
@@ -265,7 +275,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # compute the communication cost of this strategy during forward phase
@@ -309,9 +319,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
@@ -354,7 +363,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = 1
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # This strategy do not need to do all_reduce during forward phase
@@ -398,8 +407,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = 1
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                          sharding_size_weight)
 
         # This strategy do not need to do all_reduce in both forward and backward phase
         communication_cost = 0
@@ -441,11 +450,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
             mesh_dim_1]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)
 
-        # This strategy do not need to do all_reduce in both forward and backward phase
-        communication_cost = 0
+        # This strategy do not need to do all_reduce in forward phase
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            memory_cost_backward_weight, 0)
+        # compute the total communication cost
+        communication_cost = communication_cost_backward_weight + communication_cost_forward
 
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
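Here the weight is replicated on every device, so its grad must be synchronized across the whole mesh rather than along a single axis; the commit routes this through flatten_device_mesh, the 2D mesh collapsed into one dimension (hence the axis argument 0). Under the toy model above, the collective simply spans the product of both mesh dims:

# Fully replicated weight: the all-reduce spans all mesh_shape[0] * mesh_shape[1]
# devices along the flattened mesh's only axis.
flattened_size = mesh_shape[0] * mesh_shape[1]
communication_cost_backward_weight = ring_all_reduce_cost(
    memory_cost_backward_weight, flattened_size)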
@@ -485,9 +500,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
             mesh_dim_1]
         sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
 
         # compute communication cost during forward phase
         communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(