[autoparallel] add resnet autoparallel unit test and add backward weight communication cost (#1589)

Author: YuliangLiu0306
Date: 2022-09-13 18:05:05 +08:00
Committed by: GitHub
Parent: 7c18a588c8
Commit: d164449d00
2 changed files with 168 additions and 29 deletions

@@ -103,7 +103,7 @@ class ConvHandler(OperatorHandler):
         # memory_cost pair
         memory_cost = (memory_cost_forward, memory_cost_backward)
-        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
+        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

     def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
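Note on the hunk above: `_generate_memory_cost` now returns a 4-tuple, the new last element being the per-device size of the weight gradient under the given weight sharding. A minimal sketch of that shape, assuming a simple "element count divided by sharding size" cost model; the real method derives the sizes from the traced tensor metadata, so names and formulas here are illustrative only.

```python
# Hypothetical stand-in for ConvHandler._generate_memory_cost, shown only to
# illustrate the shape of the new 4-tuple; the cost model is an assumption.
def generate_memory_cost_sketch(numel_output, numel_input, numel_weight,
                                sharding_size_forward,
                                sharding_size_backward_activation,
                                sharding_size_weight):
    # per-device output activation produced during forward
    memory_cost_forward_activation = numel_output / sharding_size_forward
    # per-device input-activation gradient produced during backward
    memory_cost_backward_activation = numel_input / sharding_size_backward_activation
    # new term: per-device weight gradient, sharded the same way as the weight
    memory_cost_backward_weight = numel_weight / sharding_size_weight

    memory_cost_forward = memory_cost_forward_activation + numel_weight / sharding_size_weight
    memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
    # memory_cost pair, matching the patched return statement above
    memory_cost = (memory_cost_forward, memory_cost_backward)
    return (memory_cost, memory_cost_forward_activation,
            memory_cost_backward_activation, memory_cost_backward_weight)
```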
@@ -132,15 +132,18 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # This strategy do not need to do all_reduce operation during forward
         communication_cost_forward = 0
-        # compute the backward communication cost of this strategy
-        communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
+        # compute the backward communication cost to all reduce the input activation grad
+        communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
+                                                                                   mesh_dim_1)
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
         # total communication cost
-        communication_cost = communication_cost_forward + communication_cost_backward
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
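This hunk is the core of the cost change: the existing all-reduce of the input-activation gradient over `mesh_dim_1` is kept, and a second all-reduce of the weight gradient over `mesh_dim_0` (the data-parallel batch axis) is now charged as well. A rough, self-contained sketch of how the terms add up, using a toy ring all-reduce formula in place of `device_mesh.all_reduce_cost`, whose internals are not part of this diff; the latency and bandwidth numbers are assumptions.

```python
# Toy ring all-reduce estimate standing in for device_mesh.all_reduce_cost.
def all_reduce_cost_sketch(num_bytes, num_devices, latency=1e-5, bandwidth=100e9):
    if num_devices <= 1:
        return 0.0
    # a ring all-reduce moves roughly 2 * (n - 1) / n of the buffer per device
    return 2 * (num_devices - 1) * (latency + num_bytes / (num_devices * bandwidth))


def comm_cost_split_input_batch_weight_out_channel(bytes_backward_activation,
                                                   bytes_backward_weight,
                                                   mesh_dim_0_size, mesh_dim_1_size):
    # forward needs no all-reduce for this strategy
    communication_cost_forward = 0.0
    # input-activation grad is all-reduced over mesh_dim_1 (weight out-channel axis)
    communication_cost_backward_activation = all_reduce_cost_sketch(
        bytes_backward_activation, mesh_dim_1_size)
    # weight grad is all-reduced over mesh_dim_0 (the data-parallel batch axis)
    communication_cost_backward_weight = all_reduce_cost_sketch(
        bytes_backward_weight, mesh_dim_0_size)
    return (communication_cost_forward
            + communication_cost_backward_activation
            + communication_cost_backward_weight)
```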
@@ -178,11 +181,16 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)

-        # This strategy do not need to do all_reduce operation in both forward and backward phase.
-        communication_cost = 0
+        # This strategy do not need to do all_reduce operation in forward phase.
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute the total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_weight
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
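Before this commit the pure data-parallel strategy reported zero communication; it now charges the classic data-parallel gradient synchronization, i.e. an all-reduce of the weight gradient over `mesh_dim_0`. For a sense of the volume, an illustrative calculation for an assumed fp32 3x3 convolution with 64 input and 64 output channels (a typical ResNet block shape; the layer size is an assumption, not taken from the new unit test).

```python
# Illustrative volume of the newly charged weight-grad all-reduce; the layer
# shape and fp32 dtype are assumptions, not read from this commit.
out_channels, in_channels, kh, kw = 64, 64, 3, 3
numel_weight = out_channels * in_channels * kh * kw   # 36_864 parameters
bytes_fp32 = 4
grad_bytes = numel_weight * bytes_fp32                 # 147_456 bytes ~ 144 KiB
print(f"weight grad all-reduced over mesh_dim_0 each step: {grad_bytes} bytes")
```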
@@ -220,15 +228,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
-        # This strategy do not need to do all_reduce operation during backward phase
-        communication_cost_backward = 0
-        communication_cost = communication_cost_forward + communication_cost_backward
+        # This strategy do not need to do all_reduce operation to compute the input activation grad
+        communication_cost_backward_activation = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
+        # compute total cost
+        communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
                                                compute_cost=compute_cost,
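Here the strategy already paid a forward all-reduce of the output activation over `mesh_dim_1`; the patch splits the backward cost into a zero input-activation term and the new weight-gradient all-reduce over `mesh_dim_0`. A small self-contained check of that composition against a stubbed mesh, since the real `DeviceMesh` implementation is not shown in this diff; the buffer sizes and bandwidth are placeholders.

```python
# Stub mesh whose all_reduce_cost is a bandwidth-only estimate; the real
# DeviceMesh models its links, so treat these numbers as placeholders.
class StubMesh:
    def __init__(self, shape, bandwidth=100e9):
        self.shape = shape
        self.bandwidth = bandwidth

    def all_reduce_cost(self, num_bytes, mesh_dim):
        n = self.shape[mesh_dim]
        return 2 * (n - 1) * num_bytes / (n * self.bandwidth) if n > 1 else 0.0


mesh = StubMesh(shape=(2, 4))
communication_cost_forward = mesh.all_reduce_cost(1_000_000, 1)        # output activation, mesh_dim_1
communication_cost_backward_activation = 0.0                           # no all-reduce needed here
communication_cost_backward_weight = mesh.all_reduce_cost(147_456, 0)  # weight grad, mesh_dim_0
communication_cost = (communication_cost_forward
                      + communication_cost_backward_activation
                      + communication_cost_backward_weight)
print(communication_cost)
```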
@@ -265,7 +275,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
@@ -309,9 +319,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute the communication cost of this strategy during forward phase
         communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
@@ -354,7 +363,7 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
         sharding_size_backward_activation = 1
         sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
-        memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
+        memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
             sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # This strategy do not need to do all_reduce during forward phase
@@ -398,8 +407,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_forward = 1
         sharding_size_backward_activation = 1
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
+                                                          sharding_size_weight)

         # This strategy do not need to do all_reduce in both forward and backward phase
         communication_cost = 0
@@ -441,11 +450,17 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
             mesh_dim_1]
         sharding_size_weight = 1
-        memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
-                                                       sharding_size_weight)
+        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
+                                                                                    sharding_size_backward_activation,
+                                                                                    sharding_size_weight)

-        # This strategy do not need to do all_reduce in both forward and backward phase
-        communication_cost = 0
+        # This strategy do not need to do all_reduce in forward phase
+        communication_cost_forward = 0
+        # compute the backward communication cost to all reduce the weight due to data parallel
+        communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
+            memory_cost_backward_weight, 0)
+        # compute the total communication cost
+        communication_cost = communication_cost_backward_weight + communication_cost_forward

         sharding_strategies = ShardingStrategy(name,
                                                output_sharding_spec=sharding_spec_for_output,
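In this strategy the batch dimension is split across the whole device mesh, so the weight gradient must be all-reduced over every device; the patch expresses this as an all-reduce over dimension 0 of `flatten_device_mesh`. A sketch of that flattened-group cost under the same toy ring model as above; the latency and bandwidth values are assumptions.

```python
import math

# Toy cost of an all-reduce over the flattened (d0 * d1)-device group, standing
# in for device_mesh.flatten_device_mesh.all_reduce_cost; numbers are assumed.
def flattened_all_reduce_cost_sketch(num_bytes, mesh_shape, latency=1e-5, bandwidth=100e9):
    n = math.prod(mesh_shape)
    if n <= 1:
        return 0.0
    return 2 * (n - 1) * (latency + num_bytes / (n * bandwidth))


# e.g. the assumed 147_456-byte conv weight grad on a 2 x 4 mesh: all 8 devices
# participate, unlike the earlier strategies that reduce over a single mesh axis
cost = flattened_all_reduce_cost_sketch(147_456, (2, 4))
print(cost)
```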
@@ -485,9 +500,8 @@ class ConvHandler(OperatorHandler):
         sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
             mesh_dim_1]
         sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
-        memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
-                                                                                    sharding_size_backward_activation,
-                                                                                    sharding_size_weight)
+        memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
+            sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

         # compute communication cost during forward phase
         communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
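The second file in this commit, the ResNet autoparallel unit test named in the title, is not shown above. As rough orientation only, a hedged sketch of the kind of graph those ConvHandler strategies are generated for, using plain `torch.fx` and `torchvision`; the real test also builds a device mesh and runs the strategy constructor, whose APIs do not appear in this diff.

```python
# Hedged sketch: trace a ResNet and list the conv nodes that the ConvHandler
# above would generate sharding strategies for. This is not the test added by
# the commit, only an illustration of its subject.
import torch
from torch.fx import symbolic_trace
from torchvision.models import resnet50

model = resnet50()
gm = symbolic_trace(model)

conv_nodes = [
    node for node in gm.graph.nodes
    if node.op == 'call_module' and isinstance(gm.get_submodule(node.target), torch.nn.Conv2d)
]
print(f"{len(conv_nodes)} conv nodes would go through ConvHandler")
```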