[test] fixed tests that failed due to dtensor change (#4082)

* [test] fixed tests that failed due to dtensor change

* polish code
This commit is contained in:
Frank Lee
2023-06-26 15:50:07 +08:00
parent 92f6791095
commit c4b1b65931
37 changed files with 233 additions and 289 deletions

View File

@@ -1,3 +1,5 @@
from typing import List
import torch
from numpy import isin
from torch.fx import GraphModule
@@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
from colossalai._analyzer.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):
def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
# must turn on eval mode to ensure the output is consistent
model.eval()
inputs = data_gen()
if ignore_data is not None:
# drop the ignore_data key
inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
try:
kwargs = data_gen()
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
meta_args = {k: v.to('meta') for k, v in inputs.items()}
gm = symbolic_trace(model, meta_args=meta_args)
except Exception as e:
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
# run forward
inputs = data_gen()
non_fx_out = model(**inputs)
fx_out = gm(**inputs)

View File

@@ -15,7 +15,7 @@ SEQ_LENGTH = 16
def test_albert():
sub_registry = model_zoo.get_sub_registry('transformers_albert')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)

View File

@@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo
def test_bert():
sub_registry = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])
if __name__ == '__main__':

View File

@@ -47,7 +47,7 @@ def test_diffusers():
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()

View File

@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def test_gpt():
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
# TODO: support the following models
@@ -21,7 +21,7 @@ def test_gpt():
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
continue
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
if __name__ == '__main__':

View File

@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)

View File

@@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
def test_t5():
sub_registry = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
if name == "transformers_t5_for_conditional_generation":
# cannot trace for loss function yet
# so we use a data gen which does not produce labels
data_gen_fn = sub_registry.get('transformers_t5')[1]
model = model_fn()
trace_model_and_compare_output(model, data_gen_fn)
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
if __name__ == '__main__':