mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[test] fixed tests that failed due to dtensor change (#4082)
* [test] fixed tests that failed due to dtensor change * polish code
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from numpy import isin
|
||||
from torch.fx import GraphModule
|
||||
@@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
||||
|
||||
def trace_model_and_compare_output(model, data_gen):
|
||||
def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
|
||||
# must turn on eval mode to ensure the output is consistent
|
||||
model.eval()
|
||||
|
||||
inputs = data_gen()
|
||||
|
||||
if ignore_data is not None:
|
||||
# drop the ignore_data key
|
||||
inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
|
||||
|
||||
try:
|
||||
kwargs = data_gen()
|
||||
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
|
||||
meta_args = {k: v.to('meta') for k, v in inputs.items()}
|
||||
gm = symbolic_trace(model, meta_args=meta_args)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")
|
||||
|
||||
# run forward
|
||||
inputs = data_gen()
|
||||
non_fx_out = model(**inputs)
|
||||
fx_out = gm(**inputs)
|
||||
|
||||
|
@@ -15,7 +15,7 @@ SEQ_LENGTH = 16
|
||||
def test_albert():
    """Symbolically trace every Albert variant in the model zoo and compare
    the traced module's output against eager execution.

    Raises:
        RuntimeError: propagated from ``trace_model_and_compare_output`` when
            tracing a model fails.
    """
    sub_registry = model_zoo.get_sub_registry('transformers_albert')

    # Registry entries now unpack into five fields; the extra ``_`` matches the
    # updated model-zoo tuple introduced alongside the dtensor change.
    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()
        trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
@@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_bert():
    """Symbolically trace every BERT variant in the model zoo and compare
    traced vs. eager output.

    ``labels`` and ``next_sentence_label`` are dropped from the generated
    inputs because supplying them triggers loss computation, which the tracer
    cannot handle yet.
    """
    sub_registry = model_zoo.get_sub_registry('transformers_bert')

    # Registry entries now unpack into five fields; the extra ``_`` matches the
    # updated model-zoo tuple introduced alongside the dtensor change.
    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()
        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -47,7 +47,7 @@ def test_diffusers():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
trace_and_compare(model_fn, data, output_transform_fn)
|
||||
torch.cuda.synchronize()
|
||||
|
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_gpt():
    """Symbolically trace every GPT variant in the model zoo and compare
    traced vs. eager output.

    ``labels`` are dropped from the generated inputs because supplying them
    triggers loss computation, which the tracer cannot handle yet.
    """
    sub_registry = model_zoo.get_sub_registry('transformers_gpt')

    # Registry entries now unpack into five fields; the extra ``_`` matches the
    # updated model-zoo tuple introduced alongside the dtensor change.
    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()

        # TODO: support the following models
        # NOTE(review): the original TODO listed the unsupported models here;
        # that comment span falls between diff hunks in this excerpt.
        if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
            continue

        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@@ -12,7 +12,7 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_opt():
    """Symbolically trace every OPT variant in the model zoo and compare
    the traced module's output against eager execution.

    Raises:
        RuntimeError: propagated from ``trace_model_and_compare_output`` when
            tracing a model fails.
    """
    sub_registry = model_zoo.get_sub_registry('transformers_opt')

    # Registry entries now unpack into five fields; the extra ``_`` matches the
    # updated model-zoo tuple introduced alongside the dtensor change.
    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        model = model_fn()
        trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
@@ -12,9 +12,14 @@ from tests.kit.model_zoo import model_zoo
|
||||
def test_t5():
    """Symbolically trace every T5 variant in the model zoo and compare
    traced vs. eager output.

    The conditional-generation variant's data generator produces labels,
    which the tracer cannot handle (loss computation); it is swapped for the
    plain T5 data generator, and ``labels`` are additionally filtered out of
    the inputs for all variants.
    """
    sub_registry = model_zoo.get_sub_registry('transformers_t5')

    # Registry entries now unpack into five fields; the extra ``_`` matches the
    # updated model-zoo tuple introduced alongside the dtensor change.
    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
        if name == "transformers_t5_for_conditional_generation":
            # cannot trace for loss function yet
            # so we use a data gen which does not produce labels
            data_gen_fn = sub_registry.get('transformers_t5')[1]

        model = model_fn()
        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Reference in New Issue
Block a user