Mirror of https://github.com/hpcaitech/ColossalAI.git
[shardformer]fix gpt2 double head (#4663)
* [shardformer] fix gpt2 test
* fix
* [shardformer] add todo
@@ -58,9 +58,27 @@ def data_gen_for_sequence_classification():
 
 
 def date_gen_for_double_heads():
-    data = data_gen_for_lm()
-    data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
-    return data
+    num_choices = 2
+    batch_size = 2
+    input_ids = torch.tensor(
+        [[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
+        dtype=torch.int64)
+    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
+    mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
+
+    mc_token_ids = torch.arange(0, num_choices, dtype=torch.int64)
+    mc_token_ids = mc_token_ids.expand((batch_size, num_choices))
+    multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
+    multiple_choice_input_mask = attention_mask.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
+
+    inputs = {
+        "input_ids": multiple_choice_inputs_ids,
+        "mc_token_ids": mc_token_ids,
+        "attention_mask": multiple_choice_input_mask,
+        "labels": multiple_choice_inputs_ids,
+        "mc_labels": mc_labels,
+    }
+    return inputs
 
 
 # define output transform function
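For context, not part of the diff: the rewritten generator builds the multiple-choice layout that transformers' GPT2DoubleHeadsModel expects, with input_ids, attention_mask, and labels of shape (batch_size, num_choices, seq_len), mc_token_ids of shape (batch_size, num_choices), and mc_labels of shape (batch_size,). A minimal shape check, assuming it runs in the same module as the generator above:

# Illustrative only; these shapes follow from num_choices = batch_size = 2
# and the 8-token prompts hard-coded in date_gen_for_double_heads().
data = date_gen_for_double_heads()
assert data["input_ids"].shape == (2, 2, 8)       # (batch_size, num_choices, seq_len)
assert data["attention_mask"].shape == (2, 2, 8)  # mask replicated per choice
assert data["mc_token_ids"].shape == (2, 2)       # classification-token index per choice
assert data["labels"].shape == (2, 2, 8)          # LM labels mirror input_ids
assert data["mc_labels"].shape == (2,)            # index of the correct choice per sample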
@@ -98,14 +116,12 @@ model_zoo.register(name='transformers_gpt_lm',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
-
-# TODO The model training is failing, there is a bug in GPT2DoubleHeadsModel in transformers.
-# model_zoo.register(name='transformers_gpt_double_heads',
-#                    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-#                    data_gen_fn=date_gen_for_double_heads,
-#                    output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
-#                    loss_fn=loss_fn,
-#                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_gpt_double_heads',
+                   model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
+                   data_gen_fn=date_gen_for_double_heads,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=lambda x: x.loss + x.mc_loss,
+                   model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_question_answering',
                    model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
                    data_gen_fn=data_gen_for_question_answering,
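A note on the fix, not part of the diff: the old, commented-out registration summed the two losses inside output_transform_fn, while the re-enabled one keeps the standard output_transform_fn and moves the sum into loss_fn. A hedged sketch of what that loss_fn consumes, using a deliberately tiny GPT2Config whose sizes are made up for illustration and are not from the repo:

import transformers

# Small config so the forward pass is cheap; values are illustrative only.
tiny_config = transformers.GPT2Config(n_embd=64, n_layer=2, n_head=2)
model = transformers.GPT2DoubleHeadsModel(tiny_config)

# With both `labels` and `mc_labels` supplied, the output carries two terms:
# outputs.loss (causal-LM cross-entropy) and outputs.mc_loss (choice classification).
outputs = model(**date_gen_for_double_heads())
total = outputs.loss + outputs.mc_loss  # what loss_fn=lambda x: x.loss + x.mc_loss computes
total.backward()  # a single scalar, so both heads are trained together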