[Shardformer] Downstream bert (#3979)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage

* add downstream model of bert

* remove unused code
Author: FoolPlayer
Date: 2023-06-15 17:56:51 +08:00
Committed by: Frank Lee
Parent: c1c672d0f0
Commit: f7774ec0f3
5 changed files with 161 additions and 43 deletions
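The commit message above mentions updating the readme for the new policy usage. A minimal sketch of that usage, pieced together from the calls exercised in the test below; the import path for ShardConfig/shard_model is an assumption and may differ from the actual module layout:

    # Minimal sketch of the shard_model / ShardConfig usage exercised by the test below.
    import torch.distributed as dist
    from transformers import BertConfig, BertForSequenceClassification

    from colossalai.shardformer.shard import ShardConfig, shard_model  # assumed import path

    # Assumes colossalai.launch(...) has already initialized the process group,
    # exactly as check_bert does in the test below.
    rank, world_size = dist.get_rank(), dist.get_world_size()

    config = BertConfig.from_pretrained('bert-base-uncased')
    model = BertForSequenceClassification(config=config)

    shard_config = ShardConfig(rank=rank, world_size=world_size, gather_output=True)
    sharded_model = shard_model(model, shard_config).to('cuda')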


@@ -1,9 +1,19 @@
 import copy
 import os
 import random

 import pytest
 import torch
-from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
+from transformers import (
+    AutoTokenizer,
+    BertConfig,
+    BertForMaskedLM,
+    BertForMultipleChoice,
+    BertForNextSentencePrediction,
+    BertForPreTraining,
+    BertForSequenceClassification,
+    BertLMHeadModel,
+    BertModel,
+)
 import colossalai
 from colossalai.logging import disable_existing_loggers
@@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
 tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


-def build_model(rank, world_size):
+def build_model(rank, world_size, model):
     config = BertConfig.from_pretrained('bert-base-uncased')
     config.hidden_dropout_prob = 0
     config.attention_probs_dropout_prob = 0

-    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
+    org_model = model(config=config)
+    org_model_forshard = copy.deepcopy(org_model)
+    org_model = org_model.to('cuda')

     shardconfig = ShardConfig(
         rank=rank,
         world_size=world_size,
         gather_output=True,
     )
-    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
-                                shardconfig).to('cuda')
+    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

     return org_model, sharded_model
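For reference, a hedged sketch of what a single iteration of the downstream check amounts to, reusing the file's tokenizer and build_model; the allclose comparison is an assumption about what check_forward verifies, not a copy of it:

    # One downstream head: original model vs. its tensor-parallel-sharded copy.
    org_model, sharded_model = build_model(rank, world_size, BertForSequenceClassification)

    inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to('cuda')
    with torch.no_grad():
        org_out = org_model(**inputs)
        sharded_out = sharded_model(**inputs)

    # With gather_output=True the sharded logits should match the unsharded ones.
    assert torch.allclose(org_out.logits, sharded_out.logits, atol=1e-5)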
@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
 def check_bert(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    org_model, sharded_model = build_model(rank, world_size)
-    check_forward(org_model, sharded_model)
-    check_backward(org_model, sharded_model)
+    forward_list = [
+        BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
+        BertForSequenceClassification
+    ]
+    backward_list = [BertForMaskedLM, BertLMHeadModel]
+
+    for model in forward_list:
+        org_model, sharded_model = build_model(rank, world_size, model)
+        check_forward(org_model, sharded_model)
+        if model in backward_list:
+            check_backward(org_model, sharded_model)
+        torch.cuda.empty_cache()

     torch.cuda.empty_cache()


 @pytest.mark.dist
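check_bert above is the per-rank body of the test. As a hedged sketch (not the file's actual entry point) of how it could be spawned across two ranks with torch.multiprocessing, assuming a free_port helper is available:

    import torch.multiprocessing as mp

    from colossalai.utils import free_port  # assumed helper; any free-port picker works

    def run_bert_test(world_size: int = 2):
        # mp.spawn passes the process index (rank) as the first argument to check_bert.
        mp.spawn(check_bert, args=(world_size, free_port()), nprocs=world_size)

    if __name__ == '__main__':
        run_bert_test()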