[fx] added testing for all albert variants (#1211)

Frank Lee
2022-07-06 15:11:08 +08:00
committed by GitHub
parent 2d13a45a3b
commit 5da87ce35d
7 changed files with 193 additions and 4 deletions

@@ -0,0 +1,65 @@
import transformers
import torch
from utils import trace_model_and_compare_output

BATCH_SIZE = 2
SEQ_LENGTH = 16


def test_single_sentence_albert():
    MODEL_LIST = [
        transformers.AlbertModel,
        transformers.AlbertForPreTraining,
        transformers.AlbertForMaskedLM,
        transformers.AlbertForSequenceClassification,
        transformers.AlbertForTokenClassification,
    ]

    config = transformers.AlbertConfig(embedding_size=128,
                                       hidden_size=128,
                                       num_hidden_layers=2,
                                       num_attention_heads=4,
                                       intermediate_size=256)

    def data_gen():
        # all-zero integer tensors stand in for tokenized inputs during tracing
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        return meta_args

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        trace_model_and_compare_output(model, data_gen)


def test_multi_sentence_albert():
    config = transformers.AlbertConfig(hidden_size=128,
                                       num_hidden_layers=2,
                                       num_attention_heads=4,
                                       intermediate_size=256)
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")

    def data_gen_for_qa():
        question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
        inputs = tokenizer(question, text, return_tensors="pt")
        return inputs

    model = transformers.AlbertForQuestionAnswering(config)
    trace_model_and_compare_output(model, data_gen_for_qa)

    def data_gen_for_mcq():
        prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        choice0 = "It is eaten with a fork and a knife."
        choice1 = "It is eaten while held in the hand."
        encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
        # multiple-choice models expect inputs of shape (batch, num_choices, seq_len)
        encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
        return encoding

    model = transformers.AlbertForMultipleChoice(config)
    trace_model_and_compare_output(model, data_gen_for_mcq)


if __name__ == '__main__':
    test_single_sentence_albert()
    test_multi_sentence_albert()

@@ -0,0 +1,31 @@
import pytest
import transformers
import torch
from utils import trace_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGTH = 16


@pytest.mark.skip('value is not aligned yet')
def test_opt():
    MODEL_LIST = [
        transformers.OPTModel,
        transformers.OPTForCausalLM,
    ]

    config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

    def data_gen():
        # dummy token ids and mask, matching the OPT forward signature
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        trace_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_opt()

@@ -0,0 +1,32 @@
import pytest
import transformers
import torch
from utils import trace_model_and_compare_output

BATCH_SIZE = 1
SEQ_LENGTH = 16


@pytest.mark.skip('value is not aligned yet')
def test_t5():
    MODEL_LIST = [
        transformers.T5Model,
        transformers.T5ForConditionalGeneration,
        transformers.T5EncoderModel,
    ]

    config = transformers.T5Config(d_model=128, num_layers=2)

    def data_gen():
        # encoder-decoder models need both encoder and decoder token ids
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        return kwargs

    for model_cls in MODEL_LIST:
        model = model_cls(config=config)
        trace_model_and_compare_output(model, data_gen)


if __name__ == '__main__':
    test_t5()

@@ -30,4 +30,6 @@ def trace_model_and_compare_output(model, data_gen):
     for k in non_fx_out.keys():
         if torch.is_tensor(fx_out[k]):
-            assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
+            assert torch.equal(
+                fx_out[k], non_fx_out[k]
+            ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}'
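
For reference, the `trace_model_and_compare_output` helper that every test above imports lives in the test-local `utils` module, and this diff only shows its tail. Below is a minimal sketch of what such a helper could look like, written against Hugging Face's `transformers.utils.fx.symbolic_trace`; the real utility in this repository presumably uses ColossalAI's own fx tracer (the `meta_args` naming in the ALBERT test hints at meta-tensor tracing), so treat the tracer choice and the overall shape as assumptions, not the actual implementation.

# Hypothetical sketch, not the repository's actual utils module.
import torch
from transformers.utils.fx import symbolic_trace  # HF-aware torch.fx tracer


def trace_model_and_compare_output(model, data_gen):
    # build one batch of inputs and record the eager-mode reference output
    inputs = data_gen()
    model.eval()
    with torch.no_grad():
        non_fx_out = model(**inputs)

    # trace into a torch.fx GraphModule, declaring which kwargs are graph inputs
    gm = symbolic_trace(model, input_names=list(inputs.keys()))
    with torch.no_grad():
        fx_out = gm(**inputs)

    # every tensor field must match the eager run exactly, as asserted above
    for k in non_fx_out.keys():
        if torch.is_tensor(fx_out[k]):
            assert torch.equal(
                fx_out[k], non_fx_out[k]
            ), f'{model.__class__.__name__} has incorrect output {k}, expect {non_fx_out[k]}, but got {fx_out[k]}'

Note that the diff enforces exact `torch.equal` comparison rather than `allclose`: tracing is expected to be a pure graph rewrite, so outputs should be bit-identical to the eager run.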