Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-06 03:20:52 +00:00)
[pipeline] support shardformer for GPT2ForQuestionAnswering & complete pipeline support for GPT2 (#4245)
* change for transformers loggers
* add forward for GPT2ForQuestionAnswering
* fix assert
* fix torchrec test

Committed by Hongxin Liu · parent d9be0472ef · commit 2a2eacfaf1
@@ -29,6 +29,17 @@ def data_gen_for_lm():
     return data
 
 
+def data_gen_for_question_answering():
+    # question answering data gen
+    # `start_positions`/`end_positions` mark the answer span inside the input sequence
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 def data_gen_for_token_classification():
     # token classification data gen
     # `labels` is the type not the token id for token classification, 0 or 1
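The new generator only attaches answer-span indices to the shared data_gen() batch. As a rough illustration of what those fields drive, here is a minimal sketch (not part of the diff) of feeding such a batch to GPT2ForQuestionAnswering; the small GPT2Config and the placeholder token ids are assumptions, and it presumes a transformers release that ships GPT2ForQuestionAnswering.

# Minimal sketch (assumed: tiny GPT2Config, placeholder token ids,
# a transformers version that provides GPT2ForQuestionAnswering).
import torch
import transformers

config = transformers.GPT2Config(n_embd=128, n_layer=2, n_head=4)
model = transformers.GPT2ForQuestionAnswering(config)

batch = {
    'input_ids': torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64),
    'attention_mask': torch.ones(1, 6, dtype=torch.int64),
    'start_positions': torch.tensor([0], dtype=torch.int64),  # answer span start
    'end_positions': torch.tensor([1], dtype=torch.int64),    # answer span end
}

outputs = model(**batch)
# With start/end positions supplied, the output carries a span-extraction loss
# alongside per-token start/end logits.
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)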
@@ -82,6 +93,12 @@ model_zoo.register(name='transformers_gpt_double_heads',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_for_question_answering',
+                   model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_question_answering,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_token_classification',
                    model_fn=lambda: transformers.GPT2ForTokenClassification(config),
                    data_gen_fn=data_gen_for_token_classification,
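For context, registry entries like this one are exercised elsewhere in the test suite. The sketch below shows the general consumption pattern only; the import path, the get_sub_registry helper, and the tuple layout of each entry are assumptions, not something shown in this diff.

# Hypothetical consumer (assumed registry API): build each registered GPT2
# model, generate a batch, and check that the forward pass yields a loss.
from tests.kit.model_zoo import model_zoo  # assumed import path

sub_registry = model_zoo.get_sub_registry('transformers_gpt')  # assumed helper
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
    model = model_fn()
    data = data_gen_fn()
    outputs = model(**data)
    loss = loss_fn(output_transform_fn(outputs))
    assert loss is not None, f'{name}: forward pass produced no loss'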