mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-14 05:33:23 +00:00
[autoparallel] distinguish different parallel strategies (#2699)
This commit is contained in:
@@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(
|
||||
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
|
||||
LinearProjectionStrategyGenerator(op_data_mapping,
|
||||
self.device_mesh,
|
||||
linear_projection_type='linear',
|
||||
solver_perference=self.solver_perference))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
|
@@ -3,6 +3,7 @@ from ast import arg
|
||||
from functools import reduce
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.options import SolverPerference
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommType,
|
||||
MemoryCost,
|
||||
@@ -209,9 +210,14 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
|
||||
|
||||
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||
|
||||
def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
|
||||
def __init__(self,
|
||||
operation_data_mapping,
|
||||
device_mesh,
|
||||
linear_projection_type='linear',
|
||||
solver_perference=SolverPerference.STANDARD):
|
||||
super().__init__(operation_data_mapping, device_mesh)
|
||||
self.linear_projection_type = linear_projection_type
|
||||
self.solver_perference = solver_perference
|
||||
|
||||
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
|
||||
# C = AB
|
||||
@@ -231,16 +237,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||
total=fwd_compute_cost + bwd_compute_cost)
|
||||
strategy.compute_cost = compute_cost
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
def dp_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
|
||||
# SS = SR x RS
|
||||
strategies.append(self.split_lhs_space_rhs_space(0, 1))
|
||||
strategies.append(self.split_lhs_space_rhs_space(1, 0))
|
||||
# S01R = S01R x RR
|
||||
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
|
||||
|
||||
# SR = SS x SR
|
||||
strategies.append(self.split_lhs_space_both_contract(0, 1))
|
||||
strategies.append(self.split_lhs_space_both_contract(1, 0))
|
||||
return strategies
|
||||
|
||||
def tp_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
|
||||
# RR = RS01 x S01R
|
||||
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
|
||||
|
||||
# RS01 = RR x RS01
|
||||
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
|
||||
|
||||
# RS = RS x SS
|
||||
strategies.append(self.split_rhs_space_both_contract(0, 1))
|
||||
@@ -254,20 +266,38 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||
strategies.append(self.split_rhs_space_only(0))
|
||||
strategies.append(self.split_rhs_space_only(1))
|
||||
|
||||
# S01R = S01R x RR
|
||||
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
|
||||
return strategies
|
||||
|
||||
# RR = RS01 x S01R
|
||||
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
|
||||
def mix_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
|
||||
# RS01 = RR x RS01
|
||||
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
|
||||
# SS = SR x RS
|
||||
strategies.append(self.split_lhs_space_rhs_space(0, 1))
|
||||
strategies.append(self.split_lhs_space_rhs_space(1, 0))
|
||||
|
||||
# SR = SS x SR
|
||||
strategies.append(self.split_lhs_space_both_contract(0, 1))
|
||||
strategies.append(self.split_lhs_space_both_contract(1, 0))
|
||||
|
||||
# RR = RR x RR
|
||||
strategies.append(self.non_split())
|
||||
|
||||
return strategies
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategies = []
|
||||
|
||||
if self.solver_perference == SolverPerference.STANDARD:
|
||||
strategies.extend(self.dp_strategies())
|
||||
strategies.extend(self.tp_strategies())
|
||||
strategies.extend(self.mix_strategies())
|
||||
elif self.solver_perference == SolverPerference.DP:
|
||||
strategies.extend(self.dp_strategies())
|
||||
elif self.solver_perference == SolverPerference.TP:
|
||||
strategies.extend(self.tp_strategies())
|
||||
|
||||
return strategies
|
||||
|
||||
@ignore_sharding_exception
|
||||
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||
# handle case SS = SR x RS
|
||||
|
Reference in New Issue
Block a user