mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit * polish code
This commit is contained in:
@@ -47,7 +47,7 @@ def test_diffusers():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
trace_and_compare(model_fn, data, output_transform_fn)
|
||||
torch.cuda.synchronize()
|
||||
@@ -60,7 +60,7 @@ def test_torch_diffusers():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('diffusers')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
model = model_fn()
|
||||
output = model(**data)
|
||||
|
||||
@@ -56,7 +56,7 @@ def test_timm_models():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('timm')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
data = data_gen_fn()
|
||||
if attribute is not None and attribute.has_control_flow:
|
||||
meta_args = {k: v.to('meta') for k, v in data.items()}
|
||||
|
||||
@@ -16,7 +16,7 @@ def test_torchaudio_models():
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
|
||||
model = model_fn()
|
||||
trace_and_compare(model,
|
||||
data_gen_fn,
|
||||
|
||||
Reference in New Issue
Block a user