[autoparallel]add bcast matmul strategies (#1605)

This commit is contained in:
YuliangLiu0306
2022-09-20 11:26:21 +08:00
committed by GitHub
parent edb67cb378
commit 47b11c432c
5 changed files with 497 additions and 36 deletions

View File

@@ -14,7 +14,7 @@ ELEMENTWISE_FUNC_OP = [
RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv
operator.mul, operator.floordiv, operator.truediv, torch.matmul
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,

View File

@@ -46,6 +46,15 @@ class BcastOpHandler(OperatorHandler):
return sharding_spec
def _generate_compute_cost(self, total_sharding_size):
lhs_matrix_shape = self.lhs_data.shape[-2:]
rhs_matrix_shape = self.rhs_data.shape[-2:]
batch_dimensions_shape = self.output_data.shape[:-2]
batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
compute_cost = reduce(
operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
return compute_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
dtype = self.node._meta_data.dtype
@@ -88,44 +97,66 @@ class BcastOpHandler(OperatorHandler):
return resharding_costs
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
output_sharding_spec_list = []
output_dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(self.output_data.dim()):
for j in range(i + 1, self.output_data.dim()):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
output_dim_partition_list.append(dim_partition_dict_0)
output_dim_partition_list.append(dim_partition_dict_1)
# enumerate all the 1D sharding cases
for i in range(self.output_data.dim()):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_dict_1 = {i: [mesh_dim_1]}
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
output_dim_partition_list.append(dim_partition_dict_0)
output_dim_partition_list.append(dim_partition_dict_1)
output_dim_partition_list.append(dim_partition_dict_flatten)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
sharding_spec_list = []
check_duplicated_list = []
for output_dim_partition_dict in output_dim_partition_list:
for output_dim_partition_dict in dim_partition_list:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
output_sharding_spec_list.append(output_sharding_spec)
sharding_spec_list.append(output_sharding_spec)
return sharding_spec_list
def _enumerate_all_possible_2d_sharding(self, mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
# sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
return dim_partition_list
def _enumerate_all_possible_1d_sharding(self, mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
# sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
return dim_partition_list
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
output_dim_partition_list = []
dim_size = self.output_data.dim()
# enumerate all the 2D sharding cases
sharding_list_2d = self._enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = self._enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = self._enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
return output_sharding_spec_list
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
@@ -158,7 +189,377 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
##############################################
#used to generate strategies for torch.matmul#
##############################################
# @exception_handler
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# matrix dimensions are fully replicated, so it do not need extra process.
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
batch_sharding_dims = []
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
# in this case, total_sharding_size is equal to the batch sharding size
memory_cost = self.output_data.numel() / batch_sharding_size
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(batch_sharding_size)
# in this case, no communication takes place.
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be sharded on 'i' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'i' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is fully replicated.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of output is sharded on 'i' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be sharded on 'k' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'k' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is sharded on 'k' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is fully replicated.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
batch_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data is fully replicated on matrix dimensions.
memory_cost = self.output_data.numel() / batch_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be is sharded on 'j' dimension.
# in this case, the matrix dimensions of lhs is fully replicated.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of rhs is sharded on 'j' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is sharded on 'j' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during backward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
self._split_dim_i(dim_partition_dict, mesh_dim_list)
self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list)
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward and backward activation computation.
communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
communication_cost = communication_cost_backward_activation + communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during backward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_2d_strategies_for_matmul(self):
self._split_lhs_space_both_contract(0, 1)
self._split_lhs_space_both_contract(1, 0)
self._split_rhs_space_both_contract(0, 1)
self._split_rhs_space_both_contract(1, 0)
self._split_lhs_space_rhs_space(0, 1)
self._split_lhs_space_rhs_space(1, 0)
def register_strategy(self) -> StrategiesVector:
output_sharding_specs = self._enumerate_all_possible_output(0, 1)
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
MESH_DIM_LIST = [0, 1]
if self.node.target != torch.matmul:
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
else:
# we only care about the non-computing dimensions,
# therefore, we omit the last two dimensions.
dim_size = self.output_data.dim() - 2
# Both device mesh axises are uesd on batch dimensions
dim_partition_dicts_2d = self._enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1],
dim_size)
for dim_partition_dict in dim_partition_dicts_2d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
# Only one device mesh axis is uesd on batch dimensions
for mesh_dim_index in [0, 1]:
dim_partition_dicts_1d = self._enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index],
dim_size)
for dim_partition_dict in dim_partition_dicts_1d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
# No device mesh axis is uesd on batch dimensions
dim_partition_dict_on_batch_dim = {}
self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
self._registry_2d_strategies_for_matmul()

View File

@@ -11,7 +11,6 @@ from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
__all__ = ['DotHandler']
@@ -465,7 +464,7 @@ class DotHandler(OperatorHandler):
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}

View File

@@ -50,6 +50,15 @@ class StrategiesConstructor:
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
@@ -222,7 +231,7 @@ class StrategiesConstructor:
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP:
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)