[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)

* change for transformers loggers

* add forward for GPT2ForQuestionAnswering

* fix assert

* fix torchrec test
This commit is contained in:
Baizhou Zhang
2023-07-19 09:28:27 +08:00
committed by Hongxin Liu
parent d9be0472ef
commit 2a2eacfaf1
5 changed files with 147 additions and 11 deletions

View File

@@ -27,7 +27,6 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
inputs = data_gen_fn()
inputs = {k: v.cuda() for k, v in inputs.items()}
input_ids, _ = inputs['input_ids'], inputs['attention_mask']