Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-11 22:10:37 +00:00
fix bert unit test
@@ -1,9 +1,9 @@
 import torch
-import transformers
-from transformers import BertConfig, BertForSequenceClassification
-from packaging import version
+
+from torch.utils.data import SequentialSampler
+from transformers import BertConfig, BertForSequenceClassification
 
 from .registry import non_distributed_component_funcs
 
@@ -39,14 +39,14 @@ def get_training_components():
     num_layer = 2
 
     def bert_model_builder(checkpoint):
-        config = BertConfig(
-            gradient_checkpointing=checkpoint,
-            hidden_size=hidden_dim,
-            intermediate_size=hidden_dim * 4,
-            num_attention_heads=num_head,
-            max_position_embeddings=sequence_length,
-            num_hidden_layers=num_layer,
-        )
+        config = BertConfig(gradient_checkpointing=checkpoint,
+                            hidden_size=hidden_dim,
+                            intermediate_size=hidden_dim * 4,
+                            num_attention_heads=num_head,
+                            max_position_embeddings=sequence_length,
+                            num_hidden_layers=num_layer,
+                            hidden_dropout_prob=0.,
+                            attention_probs_dropout_prob=0.)
         print('building BertForSequenceClassification model')
 
         # adapting huggingface BertForSequenceClassification for single unitest calling interface