[fx] added testing for all gpt variants (#1210)

* [fx] added testing for all gpt variants

* polish code

* polish code
This commit is contained in:
Frank Lee
2022-07-06 14:03:13 +08:00
committed by GitHub
parent 189946c5c4
commit 2d13a45a3b
8 changed files with 136 additions and 72 deletions

View File

@@ -10,7 +10,7 @@ def test_coloproxy():
# create proxy
proxy = ColoProxy(node=node)
proxy.meta_tensor = torch.empty(4, 2, device='meta')
proxy.meta_data = torch.empty(4, 2, device='meta')
assert len(proxy) == 4
assert proxy.shape[0] == 4 and proxy.shape[1] == 2

View File

@@ -1,39 +1,11 @@
import transformers
import torch
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from utils import trace_model_and_compare_output
BATCH_SIZE = 2
SEQ_LENGHT = 16
def trace_bert_and_compare_output(model, data_gen):
tracer = ColoTracer()
# make sure that the model is traceable
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# check output
inputs = data_gen()
# must turn on eval mode to ensure the output is consistent
gm.eval()
model.eval()
# run forward
non_fx_out = model(**inputs)
fx_out = gm(**inputs)
for k in non_fx_out.keys():
assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'
def test_single_sentence_bert():
MODEL_LIST = [
transformers.BertModel,
@@ -55,7 +27,7 @@ def test_single_sentence_bert():
for model_cls in MODEL_LIST:
model = model_cls(config=config)
trace_bert_and_compare_output(model, data_gen)
trace_model_and_compare_output(model, data_gen)
def test_multi_sentence_bert():
@@ -69,7 +41,7 @@ def test_multi_sentence_bert():
return encoding
model = transformers.BertForNextSentencePrediction(config)
trace_bert_and_compare_output(model, data_gen_for_next_sentence)
trace_model_and_compare_output(model, data_gen_for_next_sentence)
def data_gen_for_qa():
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
@@ -77,7 +49,7 @@ def test_multi_sentence_bert():
return inputs
model = transformers.BertForQuestionAnswering(config)
trace_bert_and_compare_output(model, data_gen_for_qa)
trace_model_and_compare_output(model, data_gen_for_qa)
def data_gen_for_mcq():
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
@@ -88,7 +60,7 @@ def test_multi_sentence_bert():
return encoding
model = transformers.BertForMultipleChoice(config)
trace_bert_and_compare_output(model, data_gen_for_mcq)
trace_model_and_compare_output(model, data_gen_for_mcq)
if __name__ == '__main__':

View File

@@ -0,0 +1,33 @@
import transformers
import torch
from utils import trace_model_and_compare_output
BATCH_SIZE = 1
SEQ_LENGHT = 16
def test_gpt():
MODEL_LIST = [
transformers.GPT2Model,
transformers.GPT2LMHeadModel,
transformers.GPT2DoubleHeadsModel,
transformers.GPT2ForTokenClassification,
# transformers.GPT2ForSequenceClassification, # not supported yet
]
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return kwargs
for model_cls in MODEL_LIST:
model = model_cls(config=config)
trace_model_and_compare_output(model, data_gen)
if __name__ == '__main__':
test_gpt()

View File

@@ -0,0 +1,33 @@
from numpy import isin
import torch
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
def trace_model_and_compare_output(model, data_gen):
tracer = ColoTracer()
# make sure that the model is traceable
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# check output
inputs = data_gen()
# must turn on eval mode to ensure the output is consistent
gm.eval()
model.eval()
# run forward
non_fx_out = model(**inputs)
fx_out = gm(**inputs)
for k in non_fx_out.keys():
if torch.is_tensor(fx_out[k]):
assert torch.equal(fx_out[k], non_fx_out[k]), f'{model.__class__.__name__} has incorrect output {k}'