[autoparallel] distinguish different parallel strategies (#2699)

Author: YuliangLiu0306
Date: 2023-02-15 22:28:28 +08:00
Committed by: GitHub
Parent: ae86a29e23
Commit: 1dc003c169
7 changed files with 255 additions and 219 deletions

View File

@@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
         op_data_mapping = self.get_operation_data_mapping()
         generators = []
         generators.append(
-            LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
+            LinearProjectionStrategyGenerator(op_data_mapping,
+                                              self.device_mesh,
+                                              linear_projection_type='linear',
+                                              solver_perference=self.solver_perference))
         return generators

     def get_operation_data_mapping(self) -> Dict[str, OperationData]:
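
Note that the handler side only forwards the new knob: LinearModuleHandler is assumed to expose a self.solver_perference attribute (presumably populated from the solver options in one of the other files touched by this commit), and it now hands that preference to the strategy generator it builds.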

View File

@@ -3,6 +3,7 @@ from ast import arg
 from functools import reduce
 from typing import List

+from colossalai.auto_parallel.tensor_shard.options import SolverPerference
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommType,
     MemoryCost,
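
The new import pulls in the preference enum that collate_strategies() dispatches on below. As a point of reference, a minimal sketch of the shape assumed here, based only on the members this diff uses (the actual definition lives in colossalai.auto_parallel.tensor_shard.options and may carry more members or different values):

    from enum import Enum

    class SolverPerference(Enum):
        # Only the members referenced by collate_strategies() below; the
        # numeric values are illustrative, not taken from the real module.
        STANDARD = 0    # keep the full search space: DP + TP + mixed layouts
        DP = 1          # prefer data-parallel style layouts only
        TP = 2          # prefer tensor-parallel style layouts only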
@@ -209,9 +210,14 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
 class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

-    def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
+    def __init__(self,
+                 operation_data_mapping,
+                 device_mesh,
+                 linear_projection_type='linear',
+                 solver_perference=SolverPerference.STANDARD):
         super().__init__(operation_data_mapping, device_mesh)
         self.linear_projection_type = linear_projection_type
+        self.solver_perference = solver_perference

     def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         # C = AB
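
Because solver_perference defaults to SolverPerference.STANDARD, callers that do not pass a preference keep the previous behaviour: as the reworked collate_strategies() below shows, the STANDARD branch still emits the combined DP + TP + mixed strategy set.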
@@ -231,16 +237,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
                                       total=fwd_compute_cost + bwd_compute_cost)
         strategy.compute_cost = compute_cost

-    def collate_strategies(self) -> List[ShardingStrategy]:
+    def dp_strategies(self) -> List[ShardingStrategy]:
         strategies = []

-        # SS = SR x RS
-        strategies.append(self.split_lhs_space_rhs_space(0, 1))
-        strategies.append(self.split_lhs_space_rhs_space(1, 0))
+        # S01R = S01R x RR
+        strategies.append(self.split_lhs_1st_dim_1d(0, 1))

-        # SR = SS x SR
-        strategies.append(self.split_lhs_space_both_contract(0, 1))
-        strategies.append(self.split_lhs_space_both_contract(1, 0))
+        return strategies
+
+    def tp_strategies(self) -> List[ShardingStrategy]:
+        strategies = []
+
+        # RR = RS01 x S01R
+        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
+
+        # RS01 = RR x RS01
+        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

         # RS = RS x SS
         strategies.append(self.split_rhs_space_both_contract(0, 1))
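
In these strategy comments, each letter is the sharding spec of one tensor dimension, written as output = lhs x rhs: R is replicated, S is sharded along one device-mesh axis (the 0/1 arguments pick the axis), and S01 is sharded across both mesh axes flattened together. Read this way, "S01R = S01R x RR" is the pure data-parallel layout (activation sharded on its batch dimension, weight replicated), which is why it is the lone entry in dp_strategies(), while the layouts that shard the weight are grouped under tp_strategies().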
@@ -254,20 +266,38 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         strategies.append(self.split_rhs_space_only(0))
         strategies.append(self.split_rhs_space_only(1))

-        # S01R = S01R x RR
-        strategies.append(self.split_lhs_1st_dim_1d(0, 1))
+        return strategies

-        # RR = RS01 x S01R
-        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
+    def mix_strategies(self) -> List[ShardingStrategy]:
+        strategies = []

-        # RS01 = RR x RS01
-        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
+        # SS = SR x RS
+        strategies.append(self.split_lhs_space_rhs_space(0, 1))
+        strategies.append(self.split_lhs_space_rhs_space(1, 0))
+
+        # SR = SS x SR
+        strategies.append(self.split_lhs_space_both_contract(0, 1))
+        strategies.append(self.split_lhs_space_both_contract(1, 0))

         # RR = RR x RR
         strategies.append(self.non_split())

         return strategies

+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategies = []
+        if self.solver_perference == SolverPerference.STANDARD:
+            strategies.extend(self.dp_strategies())
+            strategies.extend(self.tp_strategies())
+            strategies.extend(self.mix_strategies())
+        elif self.solver_perference == SolverPerference.DP:
+            strategies.extend(self.dp_strategies())
+        elif self.solver_perference == SolverPerference.TP:
+            strategies.extend(self.tp_strategies())
+
+        return strategies
+
     @ignore_sharding_exception
     def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
         # handle case SS = SR x RS
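
To make the effect of the new knob concrete, here is a self-contained toy that mirrors only the dispatch added above. It is a stand-in, not the real generator: it uses strings where the real code builds ShardingStrategy objects from the operation data mapping and device mesh.

    from enum import Enum
    from typing import List

    class SolverPerference(Enum):
        # Stand-in for colossalai.auto_parallel.tensor_shard.options.SolverPerference.
        STANDARD = 0
        DP = 1
        TP = 2

    class ToyLinearStrategyGenerator:
        def __init__(self, solver_perference: SolverPerference = SolverPerference.STANDARD):
            self.solver_perference = solver_perference

        def dp_strategies(self) -> List[str]:
            # Pure data parallelism: shard the activation batch, replicate the weight.
            return ["S01R = S01R x RR"]

        def tp_strategies(self) -> List[str]:
            # Weight-sharding (tensor-parallel) layouts.
            return ["RR = RS01 x S01R", "RS01 = RR x RS01", "RS = RS x SS",
                    "RR = RS x SR", "RS = RR x RS"]

        def mix_strategies(self) -> List[str]:
            # Layouts that shard both operands, plus the fully replicated fallback.
            return ["SS = SR x RS", "SR = SS x SR", "RR = RR x RR"]

        def collate_strategies(self) -> List[str]:
            # Same control flow as the new collate_strategies() above.
            strategies: List[str] = []
            if self.solver_perference == SolverPerference.STANDARD:
                strategies.extend(self.dp_strategies())
                strategies.extend(self.tp_strategies())
                strategies.extend(self.mix_strategies())
            elif self.solver_perference == SolverPerference.DP:
                strategies.extend(self.dp_strategies())
            elif self.solver_perference == SolverPerference.TP:
                strategies.extend(self.tp_strategies())
            return strategies

    print(ToyLinearStrategyGenerator(SolverPerference.TP).collate_strategies())
    # Prints only the tensor-parallel layouts; STANDARD (the default) would
    # return the DP, TP and mixed layouts together.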