support kit use for bert/gpt test (#4055)

* support kit use for bert test

* support kit test for gpt2
Author: FoolPlayer
Date: 2023-06-22 10:33:06 +08:00
Committed by: Frank Lee
Parent: f22ddacef0
Commit: 7740c55c55
7 changed files with 346 additions and 273 deletions


@@ -1,7 +1,9 @@
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -82,7 +84,6 @@ class GPT2Policy(Policy):
         }
 
     def new_model_class(self):
         return self.model
 
     def postprocess(self):
@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):
 
     def __init__(self) -> None:
         super().__init__()
+
+
+# GPT2LMHeadModel
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2LMHeadModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
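The `postprocess` hook above exists because replacing `lm_head` with `Linear1D_Col` allocates a fresh weight, which silently breaks GPT-2's tie between the output head and the token embedding `transformer.wte` (`gather_output=True` makes the column-parallel head return full-vocab logits on every rank). Rebuilding one `nn.Parameter` and assigning it under both names restores the tie. A minimal, self-contained sketch of that tying pattern on plain PyTorch modules; the toy sizes and variable names are illustrative, not from this diff:

    import torch.nn as nn

    # Toy stand-ins for GPT-2's token embedding and LM head.
    wte = nn.Embedding(50257, 768)
    lm_head = nn.Linear(768, 50257, bias=False)

    # Re-tie: wrap one tensor in a single nn.Parameter and point both
    # modules at it, mirroring what postprocess does via getattr_/setattr_.
    shared = nn.Parameter(wte.weight)
    wte.weight = shared
    lm_head.weight = shared

    assert lm_head.weight is wte.weight  # one parameter, updated once

The same `module_policy`/`postprocess` pattern repeats for `GPT2DoubleHeadsModel` in the lines that follow.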
+# GPT2DoubleHeadsModel
+class GPT2DoubleHeadsModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2DoubleHeadsModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2ForTokenClassification
+class GPT2ForTokenClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# GPT2ForSequenceClassification
+class GPT2ForSequenceClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
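Taken together, these policies let a test build any GPT-2 head variant and shard it by pairing the model with its policy class. A hypothetical usage sketch, assuming this file is the `gpt2` policy module and that `ShardFormer`/`ShardConfig` with a `shard_model` entry point exist in this revision; those names are assumptions, not taken from the diff:

    from transformers import GPT2Config, GPT2LMHeadModel

    from colossalai.shardformer import ShardConfig, ShardFormer  # assumed entry points
    from colossalai.shardformer.policies.gpt2 import GPT2LMHeadModelPolicy  # assumed module path

    model = GPT2LMHeadModel(GPT2Config())

    # Hypothetical flow: shardformer applies the policy's sub-module
    # replacements (lm_head -> Linear1D_Col with gather_output=True) and
    # then calls postprocess() to restore the embedding/lm_head weight tie.
    shard_former = ShardFormer(shard_config=ShardConfig())
    sharded_model = shard_former.shard_model(model, policy=GPT2LMHeadModelPolicy())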