[Tensor] fix test_model (#916)

* polish test_model

* polish
This commit is contained in:
Ziyue Jiang
2022-05-06 18:06:22 +08:00
committed by GitHub
parent ed6426c300
commit dfaff4e243
2 changed files with 30 additions and 13 deletions

View File

@@ -37,9 +37,11 @@ def get_training_components():
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 30524
def bert_model_builder(checkpoint):
config = BertConfig(gradient_checkpointing=checkpoint,
config = BertConfig(vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,