mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 09:07:51 +00:00)
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
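The pre-commit configuration change itself is not part of the diff shown below. As a rough sketch only, a `.pre-commit-config.yaml` of the kind this commit message describes (a Python formatter plus clang-format that skips CUDA sources) might look like the following; the hook repositories, revisions, and the exclude pattern are illustrative assumptions, not the repository's actual configuration.

# Illustrative sketch only -- hook repos, revs, and the CUDA exclude pattern are assumptions,
# not the actual ColossalAI configuration.
repos:
  - repo: https://github.com/psf/black
    rev: 23.9.1
    hooks:
      - id: black              # reformats every Python file, producing diffs like the one below
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v16.0.6
    hooks:
      - id: clang-format
        exclude: \.(cu|cuh)$   # "ignore cuda for clang-format"

With such a config, `pre-commit run --all-files` rewrites every tracked file at once, which is why the diff below touches only quoting and line wrapping.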
@@ -19,44 +19,52 @@ def data_gen_fn():


def data_gen_for_pretrain():
    inputs = data_gen_fn()
    inputs['labels'] = inputs['input_ids'].clone()
    inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
    inputs["labels"] = inputs["input_ids"].clone()
    inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
    return inputs


output_transform_fn = lambda x: x

config = transformers.AlbertConfig(embedding_size=128,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256)
config = transformers.AlbertConfig(
    embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256
)

model_zoo.register(name='transformers_albert',
    model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining',
    model_fn=lambda: transformers.AlbertForPreTraining(config),
    data_gen_fn=data_gen_for_pretrain,
    output_transform_fn=lambda x: dict(loss=x.loss),
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm',
    model_fn=lambda: transformers.AlbertForMaskedLM(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_sequence_classification',
    model_fn=lambda: transformers.AlbertForSequenceClassification(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_token_classification',
    model_fn=lambda: transformers.AlbertForTokenClassification(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_albert",
    model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_albert_for_pretraining",
    model_fn=lambda: transformers.AlbertForPreTraining(config),
    data_gen_fn=data_gen_for_pretrain,
    output_transform_fn=lambda x: dict(loss=x.loss),
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_albert_for_masked_lm",
    model_fn=lambda: transformers.AlbertForMaskedLM(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_albert_for_sequence_classification",
    model_fn=lambda: transformers.AlbertForSequenceClassification(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_albert_for_token_classification",
    model_fn=lambda: transformers.AlbertForTokenClassification(config),
    data_gen_fn=data_gen_fn,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)

# ===============================
# Register multi-sentence ALBERT
@@ -80,13 +88,17 @@ def data_gen_for_mcq():
    return encoding


model_zoo.register(name='transformers_albert_for_question_answering',
    model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
    data_gen_fn=data_gen_for_qa,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_multiple_choice',
    model_fn=lambda: transformers.AlbertForMultipleChoice(config),
    data_gen_fn=data_gen_for_mcq,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_albert_for_question_answering",
    model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
    data_gen_fn=data_gen_for_qa,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_albert_for_multiple_choice",
    model_fn=lambda: transformers.AlbertForMultipleChoice(config),
    data_gen_fn=data_gen_for_mcq,
    output_transform_fn=output_transform_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -28,7 +28,7 @@ def data_gen_for_lm():
    # LM data gen
    # the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    data["labels"] = data["input_ids"].clone()
    return data


@@ -36,7 +36,7 @@ def data_gen_for_pretraining():
    # pretraining data gen
    # `next_sentence_label` is the label for next sentence prediction, 0 or 1
    data = data_gen_for_lm()
    data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
    data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64)
    return data


@@ -44,7 +44,7 @@ def data_gen_for_sequence_classification():
    # sequence classification data gen
    # `labels` is the label for sequence classification, 0 or 1
    data = data_gen()
    data['labels'] = torch.tensor([1], dtype=torch.int64)
    data["labels"] = torch.tensor([1], dtype=torch.int64)
    return data


@@ -52,7 +52,7 @@ def data_gen_for_token_classification():
    # token classification data gen
    # `labels` is the type not the token id for token classification, 0 or 1
    data = data_gen()
    data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


@@ -67,32 +67,276 @@ def data_gen_for_mcq():
    # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    # data = {k: v.unsqueeze(0) for k, v in encoding.items()}
    # data['labels'] = torch.tensor([0], dtype=torch.int64)
    input_ids = torch.tensor([[[
        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
        1012, 102, 102
    ], [
        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
        4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
        2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
    ]]])
    token_type_ids = torch.tensor([[[
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1
    ], [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
    ]]])
    attention_mask = torch.tensor([[[
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1
    ], [
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
    ]]])
    input_ids = torch.tensor(
        [
            [
                [101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010,
                 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037,
                 5442, 1012, 102, 102, 5442, 1012, 102, 102],
                [101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010,
                 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096, 2218, 1999, 1996, 2192,
                 1012, 102, 0, 0, 1012, 102, 0, 0],
            ]
        ]
    )
    token_type_ids = torch.tensor(
        [
            [
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0],
            ]
        ]
    )
    attention_mask = torch.tensor(
        [
            [
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
                 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0],
            ]
        ]
    )
    labels = torch.tensor([0], dtype=torch.int64)

    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
@@ -103,9 +347,9 @@ def data_gen_for_qa():
    # no need for labels and use start and end position instead
    data = data_gen()
    start_positions = torch.tensor([0], dtype=torch.int64)
    data['start_positions'] = start_positions
    data["start_positions"] = start_positions
    end_positions = torch.tensor([1], dtype=torch.int64)
    data['end_positions'] = end_positions
    data["end_positions"] = end_positions
    return data


@@ -114,69 +358,90 @@ output_transform_fn = lambda x: x

# define loss funciton

loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss

config = transformers.BertConfig(hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_dropout_prob=0,
    attention_probs_dropout_prob=0)
config = transformers.BertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=256,
    hidden_dropout_prob=0,
    attention_probs_dropout_prob=0,
)

# register the BERT variants
model_zoo.register(name='transformers_bert',
    model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_bert_model,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
    model_fn=lambda: transformers.BertForPreTraining(config),
    data_gen_fn=data_gen_for_pretraining,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
    model_fn=lambda: transformers.BertLMHeadModel(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
    model_fn=lambda: transformers.BertForMaskedLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
    model_fn=lambda: transformers.BertForSequenceClassification(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
    model_fn=lambda: transformers.BertForTokenClassification(config),
    data_gen_fn=data_gen_for_token_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_next_sentence',
    model_fn=lambda: transformers.BertForNextSentencePrediction(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
    model_fn=lambda: transformers.BertForMultipleChoice(config),
    data_gen_fn=data_gen_for_mcq,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_question_answering',
    model_fn=lambda: transformers.BertForQuestionAnswering(config),
    data_gen_fn=data_gen_for_qa,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_bert",
    model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_for_bert_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_pretraining",
    model_fn=lambda: transformers.BertForPreTraining(config),
    data_gen_fn=data_gen_for_pretraining,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_lm_head_model",
    model_fn=lambda: transformers.BertLMHeadModel(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_masked_lm",
    model_fn=lambda: transformers.BertForMaskedLM(config),
    data_gen_fn=data_gen_for_lm,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_sequence_classification",
    model_fn=lambda: transformers.BertForSequenceClassification(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_token_classification",
    model_fn=lambda: transformers.BertForTokenClassification(config),
    data_gen_fn=data_gen_for_token_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_next_sentence",
    model_fn=lambda: transformers.BertForNextSentencePrediction(config),
    data_gen_fn=data_gen_for_sequence_classification,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_mcq",
    model_fn=lambda: transformers.BertForMultipleChoice(config),
    data_gen_fn=data_gen_for_mcq,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
    name="transformers_bert_for_question_answering",
    model_fn=lambda: transformers.BertForQuestionAnswering(config),
    data_gen_fn=data_gen_for_qa,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -47,16 +47,20 @@ config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0

# register the blip2 variants
model_zoo.register(name='transformers_blip2',
    model_fn=lambda: transformers.Blip2Model(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_blip2",
    model_fn=lambda: transformers.Blip2Model(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)

model_zoo.register(name='transformers_blip2_conditional_gerneration',
    model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
    name="transformers_blip2_conditional_gerneration",
    model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
    data_gen_fn=data_gen,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn_blip2_model,
    model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -25,7 +25,7 @@ def data_gen_for_lm():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data['labels'] = data['input_ids'].clone()
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
@@ -33,14 +33,14 @@ def data_gen_for_token_classification():
|
||||
# token classification data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
|
||||
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_sequence_classification():
|
||||
# sequence classification data gen
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([0], dtype=torch.int64)
|
||||
data["labels"] = torch.tensor([0], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
@@ -54,62 +54,69 @@ def data_gen_for_question_answering():
|
||||
|
||||
input_ids = torch.tensor(
|
||||
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
|
||||
dtype=torch.int64)
|
||||
dtype=torch.int64,
|
||||
)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
start_positions = torch.tensor([1], dtype=torch.int64)
|
||||
end_positions = torch.tensor([10], dtype=torch.int64)
|
||||
return dict(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
start_positions=start_positions,
|
||||
end_positions=end_positions)
|
||||
return dict(
|
||||
input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions
|
||||
)
|
||||
|
||||
|
||||
# define output transform function
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
|
||||
torch.ones_like(x.last_hidden_state))
|
||||
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
|
||||
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn_for_causal_lm = lambda x: x.loss
|
||||
loss_fn_for_classification = lambda x: x.loss
|
||||
loss_fn_for_question_answering = lambda x: x.loss
|
||||
|
||||
config = transformers.BloomConfig(n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=250880,
|
||||
hidden_dropout=0,
|
||||
attention_dropout=0,
|
||||
hidden_size=64,
|
||||
pad_token_id=50256)
|
||||
config = transformers.BloomConfig(
|
||||
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
|
||||
)
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_bloom',
|
||||
model_fn=lambda: transformers.BloomModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_bloom_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_causal_lm',
|
||||
model_fn=lambda: transformers.BloomForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_causal_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_sequence_classification',
|
||||
model_fn=lambda: transformers.BloomForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_token_classification',
|
||||
model_fn=lambda: transformers.BloomForTokenClassification(config),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_bloom_for_question_answering',
|
||||
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_question_answering,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_bloom",
|
||||
model_fn=lambda: transformers.BloomModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_bloom_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_bloom_for_causal_lm",
|
||||
model_fn=lambda: transformers.BloomForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_causal_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_bloom_for_sequence_classification",
|
||||
model_fn=lambda: transformers.BloomForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_bloom_for_token_classification",
|
||||
model_fn=lambda: transformers.BloomForTokenClassification(config),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_bloom_for_question_answering",
|
||||
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_question_answering,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -1,5 +1,4 @@
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
@@ -21,8 +20,8 @@ def data_gen_for_conditional_generation():
|
||||
# token classification data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = labels
|
||||
labels = data["input_ids"].clone()
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
|
||||
@@ -30,29 +29,36 @@ def data_gen_for_conditional_generation():
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
|
||||
torch.ones_like(x.last_hidden_state))
|
||||
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
|
||||
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = ChatGLMConfig(num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
num_attention_heads=8,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True,
|
||||
torch_dtype=torch.float32)
|
||||
config = ChatGLMConfig(
|
||||
num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
num_attention_heads=8,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_chatglm',
|
||||
model_fn=lambda: ChatGLMModel(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_chatglm_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_chatglm",
|
||||
model_fn=lambda: ChatGLMModel(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_chatglm_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
|
||||
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_chatglm_for_conditional_generation",
|
||||
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -27,7 +27,7 @@ def data_gen_for_lm():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data['labels'] = data['input_ids'].clone()
|
||||
data["labels"] = data["input_ids"].clone()
|
||||
return data
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ def data_gen_for_question_answering():
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
start_positions = torch.tensor([0], dtype=torch.int64)
|
||||
data['start_positions'] = start_positions
|
||||
data["start_positions"] = start_positions
|
||||
end_positions = torch.tensor([1], dtype=torch.int64)
|
||||
data['end_positions'] = end_positions
|
||||
data["end_positions"] = end_positions
|
||||
return data
|
||||
|
||||
|
||||
@@ -46,14 +46,14 @@ def data_gen_for_token_classification():
|
||||
# token classification data gen
|
||||
# `labels` is the type not the token id for token classification, 0 or 1
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
|
||||
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_sequence_classification():
|
||||
# sequence classification data gen
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([1], dtype=torch.int64)
|
||||
data["labels"] = torch.tensor([1], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
@@ -62,7 +62,8 @@ def date_gen_for_double_heads():
|
||||
batch_size = 2
|
||||
input_ids = torch.tensor(
|
||||
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
|
||||
dtype=torch.int64)
|
||||
dtype=torch.int64,
|
||||
)
|
||||
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
|
||||
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
|
||||
|
||||
@@ -85,58 +86,73 @@ def date_gen_for_double_heads():
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss function
|
||||
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
|
||||
))
|
||||
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
|
||||
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.GPT2Config(n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification",
|
||||
pad_token_id=50256)
|
||||
config = transformers.GPT2Config(
|
||||
n_layer=2,
|
||||
n_head=4,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
resid_pdrop=0,
|
||||
summary_first_dropout=0,
|
||||
hidden_dropout=0,
|
||||
problem_type="single_label_classification",
|
||||
pad_token_id=50256,
|
||||
)
|
||||
|
||||
config_for_token_classification = copy.deepcopy(config)
|
||||
config_for_token_classification.num_labels = 2
|
||||
|
||||
# register the following models
|
||||
model_zoo.register(name='transformers_gpt',
|
||||
model_fn=lambda: transformers.GPT2Model(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_gpt2_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_lm',
|
||||
model_fn=lambda: transformers.GPT2LMHeadModel(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_double_heads',
|
||||
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
|
||||
data_gen_fn=date_gen_for_double_heads,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=lambda x: x.loss + x.mc_loss,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_question_answering',
|
||||
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_token_classification',
|
||||
model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_gpt_for_sequence_classification',
|
||||
model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_gpt",
|
||||
model_fn=lambda: transformers.GPT2Model(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_gpt2_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_gpt_lm",
|
||||
model_fn=lambda: transformers.GPT2LMHeadModel(config),
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_gpt_double_heads",
|
||||
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
|
||||
data_gen_fn=date_gen_for_double_heads,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=lambda x: x.loss + x.mc_loss,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_gpt_for_question_answering",
|
||||
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_gpt_for_token_classification",
|
||||
model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
|
||||
data_gen_fn=data_gen_for_token_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_gpt_for_sequence_classification",
|
||||
model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -4,7 +4,8 @@ import transformers
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
try:
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
|
||||
from transformers import LlamaConfig
|
||||
|
||||
HAS_LLAMA = True
|
||||
except ImportError:
|
||||
HAS_LLAMA = False
|
||||
@@ -33,8 +34,8 @@ if HAS_LLAMA:
|
||||
# label is needed for casual lm
|
||||
def data_gen_for_casual_lm():
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = labels
|
||||
labels = data["input_ids"].clone()
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
# transform the output to a dict
|
||||
@@ -45,12 +46,14 @@ if HAS_LLAMA:
|
||||
loss_fn_for_casual_lm = lambda output: output.loss
|
||||
loss_fn_for_seq_classification = lambda output: output.logits.mean()
|
||||
|
||||
config = LlamaConfig(num_hidden_layers=4,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
max_position_embeddings=128,
|
||||
num_labels=16)
|
||||
config = LlamaConfig(
|
||||
num_hidden_layers=4,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
max_position_embeddings=128,
|
||||
num_labels=16,
|
||||
)
|
||||
|
||||
if hasattr(config, "pad_token_id"):
|
||||
config.pad_token_id = config.eos_token_id
|
||||
@@ -59,21 +62,27 @@ if HAS_LLAMA:
|
||||
# transformers.LlamaModel,
|
||||
# transformers.LlamaForCausalLM,
|
||||
# transformers.LlamaForSequenceClassification,
|
||||
model_zoo.register(name='transformers_llama',
|
||||
model_fn=lambda: transformers.LlamaModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_llama_for_casual_lm',
|
||||
model_fn=lambda: transformers.LlamaForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_casual_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_casual_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_llama_for_sequence_classification',
|
||||
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_seq_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_llama",
|
||||
model_fn=lambda: transformers.LlamaModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_llama_for_casual_lm",
|
||||
model_fn=lambda: transformers.LlamaForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_casual_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_casual_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_llama_for_sequence_classification",
|
||||
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_seq_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -20,8 +20,8 @@ def data_gen_for_causal_lm():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = labels
|
||||
labels = data["input_ids"].clone()
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
|
||||
@@ -29,8 +29,8 @@ def data_gen_for_sequence_classification():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = torch.tensor([1])
|
||||
data["input_ids"].clone()
|
||||
data["labels"] = torch.tensor([1])
|
||||
return data
|
||||
|
||||
|
||||
@@ -38,14 +38,15 @@ def data_gen_for_question_answering():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data['start_positions'] = torch.tensor([0])
|
||||
data['end_positions'] = torch.tensor([1])
|
||||
data["start_positions"] = torch.tensor([0])
|
||||
data["end_positions"] = torch.tensor([1])
|
||||
return data
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
|
||||
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
|
||||
)
|
||||
loss_fn_for_lm = lambda x: x.loss
|
||||
config = transformers.OPTConfig(
|
||||
hidden_size=128,
|
||||
@@ -57,24 +58,30 @@ config = transformers.OPTConfig(
|
||||
# register the following models
|
||||
# transformers.OPTModel,
|
||||
# transformers.OPTForCausalLM,
|
||||
model_zoo.register(name='transformers_opt',
|
||||
model_fn=lambda: transformers.OPTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_opt_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_opt_for_causal_lm',
|
||||
model_fn=lambda: transformers.OPTForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_causal_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_opt_for_question_answering',
|
||||
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_opt",
|
||||
model_fn=lambda: transformers.OPTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_opt_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_opt_for_causal_lm",
|
||||
model_fn=lambda: transformers.OPTForCausalLM(config),
|
||||
data_gen_fn=data_gen_for_causal_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_opt_for_question_answering",
|
||||
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
# TODO The loss and gradient check in the test are failing, to be fixed.
|
||||
# model_zoo.register(name='transformers_opt_for_sequence_classification',
|
||||
|
@@ -28,10 +28,12 @@ def data_gen():
|
||||
original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64)
|
||||
reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64)
|
||||
input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64)
|
||||
return dict(pixel_values=pixel_values,
|
||||
original_sizes=original_sizes,
|
||||
reshaped_input_sizes=reshaped_input_sizes,
|
||||
input_points=input_points)
|
||||
return dict(
|
||||
pixel_values=pixel_values,
|
||||
original_sizes=original_sizes,
|
||||
reshaped_input_sizes=reshaped_input_sizes,
|
||||
input_points=input_points,
|
||||
)
|
||||
|
||||
|
||||
# define output transform function
|
||||
@@ -44,9 +46,11 @@ config = transformers.SamConfig()
|
||||
config.vision_config.num_hidden_layers = 2
|
||||
|
||||
# register the BERT variants
|
||||
model_zoo.register(name='transformers_sam',
|
||||
model_fn=lambda: transformers.SamModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_sam",
|
||||
model_fn=lambda: transformers.SamModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -27,7 +27,7 @@ def data_gen_for_conditional_generation():
|
||||
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
|
||||
data = data_gen_for_encoder_only()
|
||||
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
|
||||
data['labels'] = labels
|
||||
data["labels"] = labels
|
||||
return data
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ def data_gen_for_t5_model():
|
||||
# decoder_input_ids = model._shift_right(input_ids)
|
||||
data = data_gen_for_encoder_only()
|
||||
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
|
||||
data['decoder_input_ids'] = decoder_input_ids
|
||||
data["decoder_input_ids"] = decoder_input_ids
|
||||
return data
|
||||
|
||||
|
||||
@@ -55,21 +55,27 @@ config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decode
|
||||
# transformers.T5Model,
|
||||
# transformers.T5ForConditionalGeneration,
|
||||
# transformers.T5EncoderModel,
|
||||
model_zoo.register(name='transformers_t5',
|
||||
model_fn=lambda: transformers.T5Model(config),
|
||||
data_gen_fn=data_gen_for_t5_model,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_t5_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_t5_for_conditional_generation',
|
||||
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_conditional_generation,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_t5_encoder_model',
|
||||
model_fn=lambda: transformers.T5EncoderModel(config),
|
||||
data_gen_fn=data_gen_for_encoder_only,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_encoder_only,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_t5",
|
||||
model_fn=lambda: transformers.T5Model(config),
|
||||
data_gen_fn=data_gen_for_t5_model,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_t5_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_t5_for_conditional_generation",
|
||||
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_conditional_generation,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_t5_encoder_model",
|
||||
model_fn=lambda: transformers.T5EncoderModel(config),
|
||||
data_gen_fn=data_gen_for_encoder_only,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_encoder_only,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -18,15 +18,15 @@ def data_gen():
|
||||
|
||||
def data_gen_for_image_classification():
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([0])
|
||||
data["labels"] = torch.tensor([0])
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_masked_image_modeling():
|
||||
data = data_gen()
|
||||
num_patches = (config.image_size // config.patch_size)**2
|
||||
num_patches = (config.image_size // config.patch_size) ** 2
|
||||
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
|
||||
data['bool_masked_pos'] = bool_masked_pos
|
||||
data["bool_masked_pos"] = bool_masked_pos
|
||||
return data
|
||||
|
||||
|
||||
@@ -42,23 +42,29 @@ loss_fn_for_masked_image_modeling = lambda x: x.loss
|
||||
# transformers.ViTModel,
|
||||
# transformers.ViTForMaskedImageModeling,
|
||||
# transformers.ViTForImageClassification,
|
||||
model_zoo.register(name='transformers_vit',
|
||||
model_fn=lambda: transformers.ViTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_vit_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit",
|
||||
model_fn=lambda: transformers.ViTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_vit_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_vit_for_masked_image_modeling',
|
||||
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
|
||||
data_gen_fn=data_gen_for_masked_image_modeling,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_masked_image_modeling,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit_for_masked_image_modeling",
|
||||
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
|
||||
data_gen_fn=data_gen_for_masked_image_modeling,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_masked_image_modeling,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_vit_for_image_classification',
|
||||
model_fn=lambda: transformers.ViTForImageClassification(config),
|
||||
data_gen_fn=data_gen_for_image_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_image_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_vit_for_image_classification",
|
||||
model_fn=lambda: transformers.ViTForImageClassification(config),
|
||||
data_gen_fn=data_gen_for_image_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_image_classification,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
@@ -33,7 +33,7 @@ def data_gen_for_conditional_generation():
|
||||
# or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
|
||||
# only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
data = data_gen()
|
||||
data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64)
|
||||
data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ def data_gen_for_audio_classification():
|
||||
# `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
||||
# `WhisperForAudioClassification` does not need `decoder_input_ids`
|
||||
data = data_gen()
|
||||
data.pop('decoder_input_ids')
|
||||
data['labels'] = torch.tensor([1], dtype=torch.int64)
|
||||
data.pop("decoder_input_ids")
|
||||
data["labels"] = torch.tensor([1], dtype=torch.int64)
|
||||
return data
|
||||
|
||||
|
||||
@@ -69,23 +69,29 @@ config = transformers.WhisperConfig(
|
||||
)
|
||||
|
||||
# register the Whisper variants
|
||||
model_zoo.register(name='transformers_whisper',
|
||||
model_fn=lambda: transformers.WhisperModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_whisper",
|
||||
model_fn=lambda: transformers.WhisperModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_whisper_for_conditional_generation',
|
||||
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_whisper_for_conditional_generation",
|
||||
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_whisper_for_audio_classification',
|
||||
model_fn=lambda: transformers.WhisperForAudioClassification(config),
|
||||
data_gen_fn=data_gen_for_audio_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_whisper_for_audio_classification",
|
||||
model_fn=lambda: transformers.WhisperForAudioClassification(config),
|
||||
data_gen_fn=data_gen_for_audio_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_attr,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|