[autoparallel]add backward cost info into strategies (#1524)

This commit is contained in:
YuliangLiu0306
2022-09-07 11:19:00 +08:00
committed by GitHub
parent 1a3599410d
commit 0908d0fc61
3 changed files with 142 additions and 63 deletions

View File

@@ -49,11 +49,59 @@ class ConvHandler(OperatorHandler):
# 3D: (H * W * D) * N * Cout * Cin * kernel
output_size = self.output_data.shape[2:]
output_size_product = reduce(operator.mul, output_size, 1)
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
kernel_size = self.weight.shape[2:]
kernel_size_product = reduce(operator.mul, kernel_size, 1)
compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_backward_weight(int): The backward weight will be divided
into sharding_size_backward_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward = numel_output * size_per_elem_bytes / sharding_size_forward
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_backward_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward, memory_cost_backward_activation
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
@@ -76,14 +124,19 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation during forward
communication_cost_forward = 0
# compute the backward communication cost of this strategy
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
# total communication cost
communication_cost = communication_cost_forward + communication_cost_backward
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -115,13 +168,13 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy do not need to do all_reduce operation in both forward and backward phase.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
@@ -154,14 +207,18 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_1)
# This strategy do not need to do all_reduce operation during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -193,14 +250,17 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -232,13 +292,18 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward, mesh_dim_0)
# This strategy do NOT need all_reduce during forward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -270,15 +335,17 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_backward_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_backward_weight)
# This strategy do not need to do all_reduce during forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
@@ -310,12 +377,13 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
@@ -349,13 +417,14 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_backward_weight)
# This strategy do not need to do all_reduce operation
# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
@@ -390,13 +459,19 @@ class ConvHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel = self.output_data.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_backward_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_backward_weight)
# compute communication cost
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward, 0)
# This strategy do NOT need do all_reduce during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,

View File

@@ -85,12 +85,17 @@ class OperatorHandler(ABC):
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
_, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_spec)
# In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
target_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward
resharding_costs[input_node].append(resharding_cost)
return resharding_costs