[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)

* change for transformers loggers

* add forward for GPT2ForQuestionAnswering

* fix assert

* fix torchrec test
Baizhou Zhang
2023-07-19 09:28:27 +08:00
committed by Hongxin Liu
parent d9be0472ef
commit 2a2eacfaf1
5 changed files with 147 additions and 11 deletions


@@ -29,6 +29,17 @@ def data_gen_for_lm():
return data
def data_gen_for_question_answering():
# question answering data gen
    # `start_positions`/`end_positions` are token indices marking the answer span, not token ids
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
return data
def data_gen_for_token_classification():
# token classification data gen
    # `labels` is the class index (0 or 1) for each token, not the token id
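
For reference, the batch produced by data_gen_for_question_answering matches what GPT2ForQuestionAnswering.forward expects: token inputs plus start_positions/end_positions marking the answer span, from which the model computes a span-prediction loss. Below is a minimal sketch of an equivalent forward pass, assuming data_gen() yields input_ids and attention_mask like the other GPT2 data gens; the small config here is illustrative, not the one defined in the test file.

import torch
import transformers

# Illustrative small config; the test file defines its own `config`.
config = transformers.GPT2Config(n_layer=2, n_head=4, n_embd=128)
model = transformers.GPT2ForQuestionAnswering(config)

input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
attention_mask = torch.ones_like(input_ids)

# start/end positions index into the input sequence, as in data_gen_for_question_answering
outputs = model(input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=torch.tensor([0], dtype=torch.int64),
                end_positions=torch.tensor([1], dtype=torch.int64))
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)
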
@@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
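
Each registry entry bundles everything a test needs to run the model end to end. A minimal sketch of how such an entry could be exercised, assuming the registry hands back the callables under the names passed to model_zoo.register; the helper run_forward_backward is hypothetical and ColossalAI's actual test utilities may differ.

def run_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn):
    # Build the model and a synthetic batch, then run one forward/backward step.
    model = model_fn()
    data = data_gen_fn()
    outputs = model(**data)
    transformed = output_transform_fn(outputs)
    loss = loss_fn(transformed)
    loss.backward()
    return loss
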