support kit use for bert/gpt test (#4055)
* support kit use for bert test
* support kit test for gpt2
@@ -1,7 +1,9 @@
-from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
+import torch.nn as nn
+from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
 
 import colossalai.shardformer.layer as col_nn
 
 from .._utils import getattr_, setattr_
 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
+
@@ -82,7 +84,6 @@ class GPT2Policy(Policy):
         }
 
     def new_model_class(self):
-
         return self.model
 
     def postprocess(self):
@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):
 
     def __init__(self) -> None:
         super().__init__()
+
+
+# GPT2LMHeadModel
+class GPT2LMHeadModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2LMHeadModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2DoubleHeadsModel
+class GPT2DoubleHeadsModelPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def module_policy(self):
+        module_policy = super().module_policy()
+        addon_module = {
+            GPT2DoubleHeadsModel:
+                ModulePolicyDescription(attribute_replacement={},
+                                        param_replacement=[],
+                                        sub_module_replacement=[
+                                            SubModuleReplacementDescription(suffix="lm_head",
+                                                                            target_module=col_nn.Linear1D_Col,
+                                                                            kwargs={"gather_output": True})
+                                        ])
+        }
+        module_policy.update(addon_module)
+        return module_policy
+
+    def postprocess(self):
+        binding_map = {"transformer.wte.weight": "lm_head.weight"}
+        for k, v in binding_map.items():
+            param = getattr_(self.model, k)
+            param = nn.Parameter(param)
+            setattr_(self.model, k, param)
+            setattr_(self.model, v, param)
+        return self.model
+
+
+# GPT2ForTokenClassification
+class GPT2ForTokenClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+
+# GPT2ForSequenceClassification
+class GPT2ForSequenceClassificationPolicy(GPT2Policy):
+
+    def __init__(self) -> None:
+        super().__init__()
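The module_policy overrides above replace lm_head with a column-parallel linear layer (col_nn.Linear1D_Col) and pass kwargs={"gather_output": True}. Below is a toy, single-process sketch, my illustration rather than ColossalAI's implementation, of why that gather is needed: column parallelism shards the output dimension of lm_head across ranks, so the per-rank partial logits must be reassembled along the vocabulary dimension before anything like a cross-entropy loss can consume them.

    import torch

    hidden, vocab, world_size = 4, 8, 2
    x = torch.randn(3, hidden)        # a batch of hidden states
    w = torch.randn(vocab, hidden)    # the full lm_head weight

    # Column parallelism: each "rank" holds a slice of the output rows of w.
    shards = torch.chunk(w, world_size, dim=0)
    partials = [x @ shard.t() for shard in shards]   # per-rank partial logits

    # gather_output=True corresponds to an all-gather along the vocab
    # dimension, restoring full-size logits identical to the unsharded result.
    logits = torch.cat(partials, dim=-1)
    assert torch.allclose(logits, x @ w.t(), atol=1e-5)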
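Both postprocess overrides bind transformer.wte.weight and lm_head.weight to a single nn.Parameter, so the input embedding and the output projection stay tied after sharding. A minimal stand-alone sketch of the same effect, using a toy module in place of GPT2LMHeadModel and plain attribute assignment in place of the getattr_/setattr_ helpers:

    import torch.nn as nn

    class TinyLM(nn.Module):
        """Stand-in for GPT2LMHeadModel: an embedding plus an untied lm_head."""

        def __init__(self, vocab_size: int = 8, hidden: int = 4) -> None:
            super().__init__()
            self.wte = nn.Embedding(vocab_size, hidden)
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    model = TinyLM()
    # Mirror the binding_map logic: point both attribute paths at one Parameter.
    shared = nn.Parameter(model.wte.weight.detach().clone())
    model.wte.weight = shared
    model.lm_head.weight = shared
    # The two modules now share storage: updating one updates the other.
    assert model.wte.weight.data_ptr() == model.lm_head.weight.data_ptr()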
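For context, a policy here is essentially a mapping from a module class to the replacements to apply wherever that class occurs in the model tree. The hypothetical driver below uses simplified stand-ins, not the actual shardformer machinery, to show how a mapping shaped like the addon_module dicts above could be consumed.

    from dataclasses import dataclass, field
    from typing import Dict, List, Type

    import torch.nn as nn

    @dataclass
    class SubModuleReplacement:
        """Simplified stand-in for SubModuleReplacementDescription."""
        suffix: str
        target_module: Type[nn.Module]
        kwargs: dict = field(default_factory=dict)

    class StubParallelLinear(nn.Linear):
        """Placeholder for col_nn.Linear1D_Col; it only records gather_output."""

        def __init__(self, in_features: int, out_features: int, gather_output: bool = False):
            super().__init__(in_features, out_features, bias=False)
            self.gather_output = gather_output

    def apply_policy(model: nn.Module, policy: Dict[type, List[SubModuleReplacement]]) -> nn.Module:
        # Walk the module tree; when a module's class has an entry in the
        # policy, swap the named child for its "parallel" counterpart.
        for module in model.modules():
            for repl in policy.get(type(module), []):
                old = getattr(module, repl.suffix)
                new = repl.target_module(old.in_features, old.out_features, **repl.kwargs)
                setattr(module, repl.suffix, new)
        return model

    class ToyLMHeadModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.lm_head = nn.Linear(4, 8, bias=False)

    policy = {ToyLMHeadModel: [SubModuleReplacement("lm_head", StubParallelLinear, {"gather_output": True})]}
    model = apply_policy(ToyLMHeadModel(), policy)
    assert isinstance(model.lm_head, StubParallelLinear) and model.lm_head.gather_output

The real implementation does considerably more (attribute and parameter replacement, distributed process groups, weight partitioning), but the class-to-description lookup is the core idea the per-model policies in this commit plug into.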