mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-18 01:12:42 +00:00
[autoparallel] added all non-bcast matmul strategies (#1603)
This commit is contained in:
parent
db98b695b2
commit
3abf98a633
@ -13,9 +13,238 @@ from typing import List
|
|||||||
__all__ = ['DotHandler']
|
__all__ = ['DotHandler']
|
||||||
|
|
||||||
|
|
||||||
|
class DotProductStrategyGenerator(StrategyGenerator):
    """
    Generate the sharding strategies for the dot product of two 1D tensors.

    This generator targets torch.matmul when both operands are 1D tensors. As torch.matmul
    does not take a bias argument, bias is not considered here.
    """

    def validate(self, input, other) -> bool:
        """
        Check that both operands are 1D tensors.

        Args:
            input: the first operand; must expose ``dim()``.
            other: the second operand; must expose ``dim()``.

        Returns:
            bool: True when both operands are 1D.

        Raises:
            AssertionError: if either operand is not 1D.
        """
        assert input.dim() == 1 and other.dim() == 1, \
            'dot product expects both operands to be 1D tensors'
        # Report success explicitly instead of falling through with None,
        # so that callers can rely on the truthiness of the result.
        return True

    def no_split(self):
        """
        Build the strategy in which neither operand nor the output is sharded.

        Returns:
            IntermediateStrategy: the strategy named 'R = R dot R'.
        """
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name='R = R dot R', dim_partition_dict=dim_partition_dict)

    def split_one_dim(self, mesh_dim):
        """
        Build the strategy in which dim 0 of both operands is sharded along ``mesh_dim``.

        The output partition is left empty and ``mesh_dim`` is passed as the all-reduce axis.

        Args:
            mesh_dim: the device mesh axis used to shard both operands.

        Returns:
            IntermediateStrategy: the strategy for this sharding.
        """
        # NOTE(review): the name advertises an S output while the output partition is empty;
        # the string is kept unchanged because other code may look strategies up by name.
        name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def generate(self) -> List[IntermediateStrategy]:
        """
        Collect every dot product strategy produced by this generator.

        Returns:
            List[IntermediateStrategy]: the unsharded strategy followed by one sharded
            strategy for mesh axis 0 and one for mesh axis 1.
        """
        # NOTE(review): mesh axes 0 and 1 are hard-coded, which presumably assumes a 2D device mesh.
        return [
            # do not split dimensions for dot product
            # R = R dot R
            self.no_split(),
            # split two tensors in the same dimensions
            # S = S dot S
            self.split_one_dim(0),
            self.split_one_dim(1),
        ]
|
||||||
|
|
||||||
|
|
||||||
|
class MatVecStrategyGenerator(StrategyGenerator):
    """
    Generate the sharding strategies for a matrix-vector product, i.e. the first operand
    has more than one dimension and the second operand is a 1D tensor.
    """

    def validate(self, input, other) -> bool:
        """
        Check that the first operand has more than one dimension and the second is 1D.

        Args:
            input: the first operand; must expose ``dim()``.
            other: the second operand; must expose ``dim()``.

        Returns:
            bool: True when the operands have the expected dimensionality.

        Raises:
            AssertionError: if the operands do not have the expected dimensionality.
        """
        assert input.dim() > 1 and other.dim() == 1, \
            'matrix-vector product expects an input with more than 1 dim and a 1D other'
        # The signature promises a bool, so report success explicitly instead of returning None.
        return True

    def no_split(self):
        """
        Build the strategy in which neither operand nor the output is sharded.

        Returns:
            IntermediateStrategy: the strategy named 'R = R x R'.
        """
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name="R = R x R", dim_partition_dict=dim_partition_dict)

    def split_input_batch(self, mesh_dim):
        """
        Build the strategy in which dim 0 of the input and of the output is sharded
        along ``mesh_dim`` while the vector operand is not sharded.

        Args:
            mesh_dim: the device mesh axis used to shard dim 0.

        Returns:
            IntermediateStrategy: the strategy for this sharding.
        """
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def generate(self) -> List[IntermediateStrategy]:
        """
        Collect every matrix-vector strategy produced by this generator.

        Returns:
            List[IntermediateStrategy]: the unsharded strategy followed by the batch-split
            strategy for mesh axis 0 and for mesh axis 1.
        """
        # NOTE(review): mesh axes 0 and 1 are hard-coded, which presumably assumes a 2D device mesh.
        return [
            # no split
            self.no_split(),
            # split the batch dim for the first tensor only
            self.split_input_batch(0),
            self.split_input_batch(1),
        ]
|
||||||
|
|
||||||
|
|
||||||
class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q]

    In the strategy names, R stands for a replicated dimension and S<k> for a dimension
    sharded along device mesh axis k.

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This will incur extra transformation of the dim partitioning as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight for the linear module is transposed, we compute the
        # corresponding dimension indices of p and q in `other` once for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        """
        Check that the operands fit a matmul whose second tensor is 2D.

        Args:
            input: the first operand; must have at least one dimension.
            other: the second operand; must be a 2D tensor.
            bias: an optional bias; when given, its last dimension must match the q dimension of ``other``.

        Returns:
            bool: True when all checks pass.

        Raises:
            AssertionError: if any check fails.
        """
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure bias is of the same dimension as q, which sits at
        # index 0 of the transposed linear weight and at index 1 otherwise
        if self.is_linear:
            assert bias is None or bias.shape[-1] == other.shape[0]
        else:
            assert bias is None or bias.shape[-1] == other.shape[1]

        # The signature promises a bool, so report success explicitly instead of returning None.
        return True

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case SS = SR x RS: shard n along ``mesh_dim_0`` and q along ``mesh_dim_1``.
        """
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case SR = SS x SR: shard n along ``mesh_dim_0`` and the contracted
        dimension p along ``mesh_dim_1``; ``mesh_dim_1`` is passed as the all-reduce axis.
        """
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case RS = RS x SS: shard the contracted dimension p along ``mesh_dim_0``
        and q along ``mesh_dim_1``; ``mesh_dim_0`` is passed as the all-reduce axis.
        """
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                self.dim_p: [mesh_dim_0],
                self.dim_q: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        # The contracted dimension is sharded along mesh_dim_0, so this strategy declares the
        # all-reduce axis in the same way as the other strategies that shard the contraction.
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_0])

    def recompute_split_both_contract(self, mesh_dim):
        """
        Handle the case RR = RS x SR: shard only the contracted dimension p along ``mesh_dim``;
        ``mesh_dim`` is passed as the all-reduce axis.
        """
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                self.dim_p: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_rhs_space_only(self, mesh_dim):
        """
        Handle the case RS = RR x RS: shard only q along ``mesh_dim``.
        """
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }
        # NOTE(review): no contracted dimension is sharded here, yet an all-reduce axis is passed;
        # presumably it accounts for a reduction outside the forward output. Confirm the intent.
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case S01R = S01R x RR: shard n along both mesh axes at once.
        """
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case RR = RS01 x S01R: shard the contracted dimension p along both mesh
        axes at once; both axes are passed as all-reduce axes.
        """
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                self.dim_p: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_0, mesh_dim_1])

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        """
        Handle the case RS01 = RR x RS01: shard q along both mesh axes at once.
        """
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        dim_partition_dict = {
            "input": {},
            "other": {
                self.dim_q: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
|
||||||
|
|
||||||
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
||||||
@ -30,6 +259,15 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
|||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.is_torch_bmm = is_torch_bmm
|
self.is_torch_bmm = is_torch_bmm
|
||||||
|
|
||||||
|
def validate(self, input, other, bias) -> bool:
    """
    Check that the operands fit a batched matrix multiplication.

    When ``self.is_torch_bmm`` is set, the operands must be shaped like
    [..., n, m] x [..., m, p] with identical leading batch dimensions, and an
    optional bias must match p in its last dimension. Otherwise no check is
    performed yet.

    Args:
        input: the first operand.
        other: the second operand.
        bias: an optional bias tensor, or None.

    Returns:
        bool: True when all checks pass.

    Raises:
        AssertionError: if any check fails.
    """
    if self.is_torch_bmm:
        # both operands carry batch dimension(s) in front of a matrix
        assert input.dim() > 2 and other.dim() == input.dim()
        # the batch dimensions must agree, whereas the matrix dimensions only have to be
        # compatible: requiring fully equal shapes would reject any non-square product
        assert input.shape[:-2] == other.shape[:-2]
        assert input.shape[-1] == other.shape[-2]
        # bias is optional, so guard against None before reading its shape
        assert bias is None or bias.shape[-1] == other.shape[-1]
    else:
        # TODO: validate these inputs are broadcastable
        pass

    # The signature promises a bool, so report success explicitly instead of returning None.
    return True
|
||||||
|
|
||||||
def split_one_batch_dim(self):
|
def split_one_batch_dim(self):
|
||||||
if 1 in self.device_mesh.mesh_shape:
|
if 1 in self.device_mesh.mesh_shape:
|
||||||
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||||
|
@ -32,4 +32,14 @@ class StrategyGenerator(ABC):
|
|||||||
|
|
||||||
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
    """
    Generate the candidate sharding strategies for the current operation.

    Subclasses must override this method.

    Returns:
        List[IntermediateStrategy]: the strategies produced by this generator.
    """
    pass
|
||||||
|
|
||||||
|
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
    """
    Check whether the operands have the shape this generator expects.

    A True result means that this generator is applicable to the current operation.
    Subclasses must override this method.
    """
|
||||||
|
Loading…
Reference in New Issue
Block a user