fix bert unit test

This commit is contained in:
ver217
2022-03-09 13:38:20 +08:00
committed by Frank Lee
parent 5663616921
commit f5f0ad266e
2 changed files with 13 additions and 15 deletions

View File

@@ -1,9 +1,9 @@
import torch
import transformers
from transformers import BertConfig, BertForSequenceClassification
from packaging import version
from torch.utils.data import SequentialSampler
from transformers import BertConfig, BertForSequenceClassification
from .registry import non_distributed_component_funcs
@@ -39,14 +39,14 @@ def get_training_components():
num_layer = 2
def bert_model_builder(checkpoint):
config = BertConfig(
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
)
config = BertConfig(gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.,
attention_probs_dropout_prob=0.)
print('building BertForSequenceClassification model')
    # adapting huggingface BertForSequenceClassification for single unit-test calling interface