integrate with dist layer (#4011)

This commit is contained in:
FoolPlayer
2023-06-16 11:23:30 +08:00
committed by Frank Lee
parent 015af592f8
commit dfca9678fa
3 changed files with 42 additions and 24 deletions

View File

@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ParallelModule:
    """Empty placeholder type.

    The constructor takes no arguments and sets no instance state; the
    class defines no other attributes or methods.
    """

    def __init__(self):
        # Intentionally a no-op: instances carry no state.
        pass
class BertPolicy(Policy):
def preprocess(self, shard_config: ShardConfig = None):
@@ -49,7 +43,27 @@ class BertPolicy(Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=ParallelModule,
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
),
])
}