mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-01 09:42:35 +00:00
[autoparallel] added strategy generator and bmm strategies (#1602)
This commit is contained in:
parent
a19eb80998
commit
db98b695b2
@ -1,15 +1,168 @@
|
|||||||
import operator
|
import operator
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||||
from .operator_handler import OperatorHandler
|
from .operator_handler import OperatorHandler
|
||||||
|
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
|
from enum import Enum
|
||||||
|
from .strategy_generator import StrategyGenerator, IntermediateStrategy
|
||||||
|
from typing import List
|
||||||
|
|
||||||
__all__ = ['DotHandler']
|
__all__ = ['DotHandler']
|
||||||
|
|
||||||
|
|
||||||
|
class MatMulStrategyGenerator(StrategyGenerator):
|
||||||
|
# TODO: to be implmented
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BatchedMatMulStrategyGenerator(StrategyGenerator):
|
||||||
|
"""
|
||||||
|
Generate sharding strategies for the batched matrix multiplication.
|
||||||
|
|
||||||
|
A batched matrix multiplication can be viewed as
|
||||||
|
[b, i, k] x [b, k, j] -> [b, i, j]
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.is_torch_bmm = is_torch_bmm
|
||||||
|
|
||||||
|
def split_one_batch_dim(self):
|
||||||
|
if 1 in self.device_mesh.mesh_shape:
|
||||||
|
mesh_dim = self.device_mesh.mesh_shape.index(1)
|
||||||
|
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||||
|
dim_partition_dict = {
|
||||||
|
"input": {
|
||||||
|
0: [mesh_dim]
|
||||||
|
},
|
||||||
|
"other": {
|
||||||
|
0: [mesh_dim]
|
||||||
|
},
|
||||||
|
"bias": {},
|
||||||
|
"output": {
|
||||||
|
0: [mesh_dim]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
|
||||||
|
dim_partition_dict = {
|
||||||
|
"input": {
|
||||||
|
0: [mesh_dim_0, mesh_dim_1]
|
||||||
|
},
|
||||||
|
"other": {
|
||||||
|
0: [mesh_dim_0, mesh_dim_1]
|
||||||
|
},
|
||||||
|
"bias": {},
|
||||||
|
"output": {
|
||||||
|
0: [mesh_dim_0, mesh_dim_1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
|
||||||
|
def split_one_batch_dim(self, mesh_dim):
|
||||||
|
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
|
||||||
|
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
|
||||||
|
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
|
||||||
|
dim_partition_dict = {
|
||||||
|
"input": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-2: [mesh_dim_1]
|
||||||
|
},
|
||||||
|
"other": {
|
||||||
|
0: [mesh_dim_0]
|
||||||
|
},
|
||||||
|
"bias": {},
|
||||||
|
"output": {
|
||||||
|
0: mesh_dim_0,
|
||||||
|
-2: [mesh_dim_1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
|
||||||
|
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
|
||||||
|
dim_partition_dict = {
|
||||||
|
"input": {
|
||||||
|
0: [mesh_dim_0]
|
||||||
|
},
|
||||||
|
"other": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-1: [mesh_dim_1]
|
||||||
|
},
|
||||||
|
"bias": {
|
||||||
|
-1: [mesh_dim_1]
|
||||||
|
},
|
||||||
|
"output": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-1: [mesh_dim_1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
|
||||||
|
|
||||||
|
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
|
||||||
|
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
|
||||||
|
dim_partition_dict = {
|
||||||
|
"input": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-1: [mesh_dim_1]
|
||||||
|
},
|
||||||
|
"other": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-2: [mesh_dim_1]
|
||||||
|
},
|
||||||
|
"bias": {},
|
||||||
|
"output": {
|
||||||
|
0: [mesh_dim_0],
|
||||||
|
-2: [mesh_dim_1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
|
||||||
|
|
||||||
|
def generate(self) -> List[IntermediateStrategy]:
|
||||||
|
strategy_list = []
|
||||||
|
|
||||||
|
# split only the batch dimension
|
||||||
|
# Sb = Sb x Sb
|
||||||
|
# can be None as it is only for 1D device mesh
|
||||||
|
strategy = self.split_one_batch_dim()
|
||||||
|
if strategy:
|
||||||
|
strategy_list.append(strategy)
|
||||||
|
|
||||||
|
# split batch dim of two inputs and the i dim of the first tensor
|
||||||
|
# SbSi = SbSi x Sb
|
||||||
|
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
|
||||||
|
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
|
||||||
|
|
||||||
|
# split batch dim of two inputs and the j of the second tensor
|
||||||
|
# SbSj = Sb x SbSj
|
||||||
|
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
|
||||||
|
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
|
||||||
|
|
||||||
|
# split batch dim of two inputs and the k dim of two inputs
|
||||||
|
# Sb = SbSk x SbSk, need to all-reduce by k dim
|
||||||
|
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
|
||||||
|
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
|
||||||
|
|
||||||
|
# split two batch dim
|
||||||
|
strategy_list.append(self.split_two_batch_dim(0, 1))
|
||||||
|
strategy_list.append(self.split_two_batch_dim(1, 0))
|
||||||
|
|
||||||
|
return strategy_list
|
||||||
|
|
||||||
|
|
||||||
class DotHandler(OperatorHandler):
|
class DotHandler(OperatorHandler):
|
||||||
"""
|
"""
|
||||||
A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
|
A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
@ -297,7 +450,7 @@ class DotHandler(OperatorHandler):
|
|||||||
|
|
||||||
def register_strategy(self) -> StrategiesVector:
|
def register_strategy(self) -> StrategiesVector:
|
||||||
'''
|
'''
|
||||||
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
|
Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
|
||||||
|
|
||||||
Output:
|
Output:
|
||||||
|
|
||||||
|
@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
|
|||||||
from torch.fx.node import Node
|
from torch.fx.node import Node
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
from colossalai.device.device_mesh import DeviceMesh
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
|
|
||||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||||
from .._utils import generate_resharding_costs, generate_sharding_spec
|
from .._utils import generate_resharding_costs, generate_sharding_spec
|
||||||
from colossalai.auto_parallel.solver.constants import *
|
from colossalai.auto_parallel.solver.constants import *
|
||||||
|
@ -0,0 +1,35 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict
|
||||||
|
from colossalai.device.device_mesh import DeviceMesh
|
||||||
|
|
||||||
|
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IntermediateStrategy:
|
||||||
|
"""
|
||||||
|
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
|
||||||
|
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name (str): name of the sharding strategy.
|
||||||
|
dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
|
||||||
|
all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
|
||||||
|
"""
|
||||||
|
name: str
|
||||||
|
dim_partition_dict: Dict[str, Dict[int, List[int]]]
|
||||||
|
all_reduce_axis: List[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyGenerator(ABC):
|
||||||
|
"""
|
||||||
|
StrategyGenerator is used to generate the same group of sharding strategies.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, device_mesh: DeviceMesh):
|
||||||
|
self.device_mesh = device_mesh
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def generate(self) -> List[IntermediateStrategy]:
|
||||||
|
pass
|
Loading…
Reference in New Issue
Block a user