[shardformer] add gpt2 test and layer class refactor (#4041)

* add gpt2 test and layer class refactor

* add dropout in gpt2 policy
This commit is contained in:
FoolPlayer
2023-06-20 11:45:16 +08:00
committed by Frank Lee
parent d857f3dbba
commit 4021b9a8a2
14 changed files with 1400 additions and 840 deletions

View File

@@ -1,7 +1,7 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, BertLMPredictionHead
import colossalai.shardformer.layer.layers as col_nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D
from ..utils import getattr_, setattr_
@@ -87,15 +87,9 @@ class BertPolicy(Policy):
def new_model_class(self):
    """Return the model object this policy operates on.

    NOTE(review): the original span contained both ``return None`` (with a
    ``# do nothing`` comment) and ``return self.model`` — residue of a diff
    whose +/- markers were stripped. The hunk header (``@@ -87,15 +87,9``)
    shows ``return None`` was the removed line, so the reachable body keeps
    the added ``return self.model``; the unreachable statement is dropped.
    """
    return self.model
def postprocess(self):
    """Tie the input word-embedding weight to the LM prediction-head decoder weight.

    Wraps the embedding tensor in a single ``nn.Parameter`` and assigns that
    same object to both attribute paths, so the two modules share storage.

    Returns:
        The model, with the weights tied in place.
    """
    bindings = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    for src_path, dst_path in bindings.items():
        shared = nn.Parameter(getattr_(self.model, src_path))
        setattr_(self.model, src_path, shared)
        setattr_(self.model, dst_path, shared)
    return self.model
@@ -127,6 +121,15 @@ class BertForPretrainingPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
    """Tie the word-embedding weight to the prediction-head decoder weight.

    Returns the model after rebinding, so callers can chain on the result.
    """
    binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    for k, v in binding_map.items():
        # Fetch the source tensor, promote it to a Parameter, then write the
        # SAME object back to both attribute paths so the weights stay shared.
        param = getattr_(self.model, k)
        param = nn.Parameter(param)
        setattr_(self.model, k, param)
        setattr_(self.model, v, param)
    return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
@@ -149,6 +152,15 @@ class BertForMaskedLMPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
    """Share the embedding weight with the masked-LM decoder weight.

    A single ``nn.Parameter`` is installed at both attribute paths so the
    embedding and the decoder reuse the same underlying storage.

    Returns:
        The model with the tied parameter in place.
    """
    tied_pairs = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    for embed_path, decoder_path in tied_pairs.items():
        tied = nn.Parameter(getattr_(self.model, embed_path))
        for path in (embed_path, decoder_path):
            setattr_(self.model, path, tied)
    return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@@ -171,6 +183,15 @@ class BertLMHeadModelPolicy(BertPolicy):
module_policy.update(addon_module)
return module_policy
def postprocess(self):
    """Tie the word-embedding weight to the LM-head decoder weight.

    Returns the model after the rebinding is done in place.
    """
    binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
    for k, v in binding_map.items():
        # Wrap the embedding tensor in a Parameter and assign the SAME object
        # to both paths — assigning it twice is what ties the weights.
        param = getattr_(self.model, k)
        param = nn.Parameter(param)
        setattr_(self.model, k, param)
        setattr_(self.model, v, param)
    return self.model
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):