[shardformer] add linearconv1d test (#4067)

* add linearconv1d test

* add linearconv1d test
This commit is contained in:
FoolPlayer
2023-06-22 14:40:37 +08:00
committed by Frank Lee
parent 8eb09a4c69
commit 0803a61412
4 changed files with 122 additions and 34 deletions

View File

@@ -42,9 +42,6 @@ def check_gpt2(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
print(name)
# if name == 'transformers_gpt':
# continue
org_model, sharded_model = build_model(world_size, model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)