[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
Author: Hongxin Liu
Date: 2023-09-19 14:20:26 +08:00
Committed by: GitHub
Parent: 3c6b831c26
Commit: 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -19,44 +19,52 @@ def data_gen_fn():
def data_gen_for_pretrain():
inputs = data_gen_fn()
inputs['labels'] = inputs['input_ids'].clone()
inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
inputs["labels"] = inputs["input_ids"].clone()
inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
return inputs
output_transform_fn = lambda x: x
config = transformers.AlbertConfig(embedding_size=128,
hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256)
config = transformers.AlbertConfig(
embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256
)
model_zoo.register(name='transformers_albert',
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining',
model_fn=lambda: transformers.AlbertForPreTraining(config),
data_gen_fn=data_gen_for_pretrain,
output_transform_fn=lambda x: dict(loss=x.loss),
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm',
model_fn=lambda: transformers.AlbertForMaskedLM(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_sequence_classification',
model_fn=lambda: transformers.AlbertForSequenceClassification(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_token_classification',
model_fn=lambda: transformers.AlbertForTokenClassification(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_albert",
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_albert_for_pretraining",
model_fn=lambda: transformers.AlbertForPreTraining(config),
data_gen_fn=data_gen_for_pretrain,
output_transform_fn=lambda x: dict(loss=x.loss),
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_albert_for_masked_lm",
model_fn=lambda: transformers.AlbertForMaskedLM(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_albert_for_sequence_classification",
model_fn=lambda: transformers.AlbertForSequenceClassification(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_albert_for_token_classification",
model_fn=lambda: transformers.AlbertForTokenClassification(config),
data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
# ===============================
# Register multi-sentence ALBERT
@@ -80,13 +88,17 @@ def data_gen_for_mcq():
return encoding
model_zoo.register(name='transformers_albert_for_question_answering',
model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_multiple_choice',
model_fn=lambda: transformers.AlbertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_albert_for_question_answering",
model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_albert_for_multiple_choice",
model_fn=lambda: transformers.AlbertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
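
As a side note, the registered entries are plain callables, so the pretraining case can be exercised on its own. Below is a minimal, self-contained sketch assuming a batch size of 1 and a sequence length of 8 (the real data_gen_fn, BATCH_SIZE and SEQ_LENGTH are defined above this hunk and may differ):

import torch
import transformers

# Tiny ALBERT config mirroring the one registered above.
config = transformers.AlbertConfig(
    embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256
)
model = transformers.AlbertForPreTraining(config)

# Stand-in for data_gen_for_pretrain(): MLM labels reuse input_ids,
# and sentence_order_label is a per-example 0/1 sentence-order target.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
inputs = dict(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    labels=input_ids.clone(),
    sentence_order_label=torch.zeros(1, dtype=torch.int64),
)

outputs = model(**inputs)
# output_transform_fn for this entry keeps only the loss.
print(dict(loss=outputs.loss))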

View File

@@ -28,7 +28,7 @@ def data_gen_for_lm():
# LM data gen
# the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
data["labels"] = data["input_ids"].clone()
return data
@@ -36,7 +36,7 @@ def data_gen_for_pretraining():
# pretraining data gen
# `next_sentence_label` is the label for next sentence prediction, 0 or 1
data = data_gen_for_lm()
data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
data["next_sentence_label"] = torch.tensor([1], dtype=torch.int64)
return data
@@ -44,7 +44,7 @@ def data_gen_for_sequence_classification():
# sequence classification data gen
# `labels` is the label for sequence classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([1], dtype=torch.int64)
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
@@ -52,7 +52,7 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the class type, not the token id, for token classification: 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
data["labels"] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
@@ -67,32 +67,276 @@ def data_gen_for_mcq():
# data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
# data = {k: v.unsqueeze(0) for k, v in encoding.items()}
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102, 102, 5442,
1012, 102, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0, 0, 1012, 102, 0, 0
]]])
token_type_ids = torch.tensor([[[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1
],
[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
]]])
attention_mask = torch.tensor([[[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1
],
[
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0
]]])
input_ids = torch.tensor(
[
[
[
101,
1999,
3304,
1010,
10733,
2366,
1999,
5337,
10906,
1010,
2107,
2004,
2012,
1037,
4825,
1010,
2003,
3591,
4895,
14540,
6610,
2094,
1012,
102,
2009,
2003,
8828,
2007,
1037,
9292,
1998,
1037,
5442,
1012,
102,
102,
5442,
1012,
102,
102,
],
[
101,
1999,
3304,
1010,
10733,
2366,
1999,
5337,
10906,
1010,
2107,
2004,
2012,
1037,
4825,
1010,
2003,
3591,
4895,
14540,
6610,
2094,
1012,
102,
2009,
2003,
8828,
2096,
2218,
1999,
1996,
2192,
1012,
102,
0,
0,
1012,
102,
0,
0,
],
]
]
)
token_type_ids = torch.tensor(
[
[
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
],
[
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
0,
1,
1,
0,
0,
],
]
]
)
attention_mask = torch.tensor(
[
[
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
],
[
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
1,
0,
0,
1,
1,
0,
0,
],
]
]
)
labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
@@ -103,9 +347,9 @@ def data_gen_for_qa():
# no need for labels; use start and end positions instead
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
data["end_positions"] = end_positions
return data
@@ -114,69 +358,90 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
config = transformers.BertConfig(hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256,
hidden_dropout_prob=0,
attention_probs_dropout_prob=0)
config = transformers.BertConfig(
hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256,
hidden_dropout_prob=0,
attention_probs_dropout_prob=0,
)
# register the BERT variants
model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
model_fn=lambda: transformers.BertForPreTraining(config),
data_gen_fn=data_gen_for_pretraining,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
model_fn=lambda: transformers.BertLMHeadModel(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
model_fn=lambda: transformers.BertForMaskedLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
model_fn=lambda: transformers.BertForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
model_fn=lambda: transformers.BertForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_next_sentence',
model_fn=lambda: transformers.BertForNextSentencePrediction(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
model_fn=lambda: transformers.BertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_question_answering',
model_fn=lambda: transformers.BertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_bert",
model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_pretraining",
model_fn=lambda: transformers.BertForPreTraining(config),
data_gen_fn=data_gen_for_pretraining,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_lm_head_model",
model_fn=lambda: transformers.BertLMHeadModel(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_masked_lm",
model_fn=lambda: transformers.BertForMaskedLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_sequence_classification",
model_fn=lambda: transformers.BertForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_token_classification",
model_fn=lambda: transformers.BertForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_next_sentence",
model_fn=lambda: transformers.BertForNextSentencePrediction(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_mcq",
model_fn=lambda: transformers.BertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bert_for_question_answering",
model_fn=lambda: transformers.BertForQuestionAnswering(config),
data_gen_fn=data_gen_for_qa,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
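
The two loss lambdas above differ because the bare BertModel has no task head and accepts no labels, so it never returns a .loss; the harness instead regresses its hidden states against ones to get a differentiable scalar, while every head model can use x.loss directly. A rough sketch of that distinction, with an arbitrary toy input (the random input_ids are an assumption, not the zoo's data_gen):

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
input_ids = torch.randint(0, config.vocab_size, (1, 8))

# Bare encoder: no labels, no .loss, so use an MSE against ones as a stand-in loss.
base = transformers.BertModel(config, add_pooling_layer=False)
out = base(input_ids=input_ids)
synthetic_loss = torch.nn.functional.mse_loss(out.last_hidden_state, torch.ones_like(out.last_hidden_state))

# Task head: labels go in, a real .loss comes out.
clf = transformers.BertForSequenceClassification(config)
real_loss = clf(input_ids=input_ids, labels=torch.tensor([1])).loss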

View File

@@ -47,16 +47,20 @@ config.qformer_config.hidden_dropout_prob = 0
config.text_config.dropout = 0
# register the blip2 variants
model_zoo.register(name='transformers_blip2',
model_fn=lambda: transformers.Blip2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_blip2_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_blip2",
model_fn=lambda: transformers.Blip2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_blip2_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name='transformers_blip2_conditional_gerneration',
model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_blip2_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_blip2_conditional_gerneration",
model_fn=lambda: transformers.Blip2ForConditionalGeneration(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_blip2_model,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@@ -25,7 +25,7 @@ def data_gen_for_lm():
# LM data gen
# the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
data["labels"] = data["input_ids"].clone()
return data
@@ -33,14 +33,14 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the class type, not the token id, for token classification: 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data['labels'] = torch.tensor([0], dtype=torch.int64)
data["labels"] = torch.tensor([0], dtype=torch.int64)
return data
@@ -54,62 +54,69 @@ def data_gen_for_question_answering():
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
dtype=torch.int64)
dtype=torch.int64,
)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
return dict(input_ids=input_ids,
attention_mask=attention_mask,
start_positions=start_positions,
end_positions=end_positions)
return dict(
input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions
)
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
torch.ones_like(x.last_hidden_state))
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
config = transformers.BloomConfig(n_layer=2,
n_head=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64,
pad_token_id=50256)
config = transformers.BloomConfig(
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
)
# register the following models
model_zoo.register(name='transformers_bloom',
model_fn=lambda: transformers.BloomModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bloom_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_causal_lm',
model_fn=lambda: transformers.BloomForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_sequence_classification',
model_fn=lambda: transformers.BloomForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_token_classification',
model_fn=lambda: transformers.BloomForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bloom_for_question_answering',
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_question_answering,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_bloom",
model_fn=lambda: transformers.BloomModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bloom_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bloom_for_causal_lm",
model_fn=lambda: transformers.BloomForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bloom_for_sequence_classification",
model_fn=lambda: transformers.BloomForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bloom_for_token_classification",
model_fn=lambda: transformers.BloomForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_bloom_for_question_answering",
model_fn=lambda: transformers.BloomForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_question_answering,
model_attribute=ModelAttribute(has_control_flow=True),
)
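
The question-answering entry is supervised with start_positions and end_positions rather than labels: the QA head scores each token as a possible answer start and end, and folds the two cross-entropy terms into .loss. A small sketch reusing the hard-coded ids from data_gen_for_question_answering above:

import torch
import transformers

config = transformers.BloomConfig(
    n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
)
model = transformers.BloomForQuestionAnswering(config)

input_ids = torch.tensor(
    [[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
    dtype=torch.int64,
)
outputs = model(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    start_positions=torch.tensor([1], dtype=torch.int64),
    end_positions=torch.tensor([10], dtype=torch.int64),
)
# loss_fn_for_question_answering above simply reads this scalar.
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)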

View File

@@ -1,5 +1,4 @@
import torch
import transformers
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
@@ -21,8 +20,8 @@ def data_gen_for_conditional_generation():
# conditional generation data gen
# `labels` are the target token ids; reuse `input_ids` as `labels`
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
labels = data["input_ids"].clone()
data["labels"] = labels
return data
@@ -30,29 +29,36 @@ def data_gen_for_conditional_generation():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
torch.ones_like(x.last_hidden_state))
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
config = ChatGLMConfig(num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
rmsnorm=True,
original_rope=True,
use_cache=True,
torch_dtype=torch.float32)
config = ChatGLMConfig(
num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
num_attention_heads=8,
rmsnorm=True,
original_rope=True,
use_cache=True,
torch_dtype=torch.float32,
)
model_zoo.register(name='transformers_chatglm',
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_chatglm_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_chatglm",
model_fn=lambda: ChatGLMModel(config, empty_init=False),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_chatglm_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_chatglm_for_conditional_generation",
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@@ -27,7 +27,7 @@ def data_gen_for_lm():
# LM data gen
# the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
data["labels"] = data["input_ids"].clone()
return data
@@ -36,9 +36,9 @@ def data_gen_for_question_answering():
# no `labels` here; question answering uses `start_positions` and `end_positions` instead
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data['start_positions'] = start_positions
data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data['end_positions'] = end_positions
data["end_positions"] = end_positions
return data
@@ -46,14 +46,14 @@ def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1]], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data['labels'] = torch.tensor([1], dtype=torch.int64)
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
@@ -62,7 +62,8 @@ def date_gen_for_double_heads():
batch_size = 2
input_ids = torch.tensor(
[[15496, 11, 616, 3290, 318, 13779, 318, 13779], [15496, 11, 616, 3290, 318, 13779, 318, 13779]],
dtype=torch.int64)
dtype=torch.int64,
)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
mc_labels = torch.zeros(input_ids.shape[0], dtype=torch.int64)
@@ -85,58 +86,73 @@ def date_gen_for_double_heads():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
config = transformers.GPT2Config(n_layer=2,
n_head=4,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256)
config = transformers.GPT2Config(
n_layer=2,
n_head=4,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256,
)
config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2
# register the following models
model_zoo.register(name='transformers_gpt',
model_fn=lambda: transformers.GPT2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_gpt2_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_lm',
model_fn=lambda: transformers.GPT2LMHeadModel(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
data_gen_fn=date_gen_for_double_heads,
output_transform_fn=output_transform_fn,
loss_fn=lambda x: x.loss + x.mc_loss,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_question_answering',
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_gpt",
model_fn=lambda: transformers.GPT2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_gpt2_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gpt_lm",
model_fn=lambda: transformers.GPT2LMHeadModel(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gpt_double_heads",
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
data_gen_fn=date_gen_for_double_heads,
output_transform_fn=output_transform_fn,
loss_fn=lambda x: x.loss + x.mc_loss,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gpt_for_question_answering",
model_fn=lambda: transformers.GPT2ForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gpt_for_token_classification",
model_fn=lambda: transformers.GPT2ForTokenClassification(config_for_token_classification),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gpt_for_sequence_classification",
model_fn=lambda: transformers.GPT2ForSequenceClassification(config_for_token_classification),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
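
The "use input_ids as labels" pattern in data_gen_for_lm works because GPT2LMHeadModel shifts labels internally (the logit at position i is scored against the label at position i + 1), so no manual shifting is needed. A minimal sketch with an arbitrary toy sequence:

import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=4, vocab_size=50258, pad_token_id=50256)
model = transformers.GPT2LMHeadModel(config)

input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)

# No manual shift: internally the model aligns logits[:, :-1] with labels[:, 1:].
outputs = model(input_ids=input_ids, attention_mask=torch.ones_like(input_ids), labels=input_ids.clone())
print(outputs.loss)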

View File

@@ -4,7 +4,8 @@ import transformers
from ..registry import ModelAttribute, model_zoo
try:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers import LlamaConfig
HAS_LLAMA = True
except ImportError:
HAS_LLAMA = False
@@ -33,8 +34,8 @@ if HAS_LLAMA:
# label is needed for causal lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
labels = data["input_ids"].clone()
data["labels"] = labels
return data
# transform the output to a dict
@@ -45,12 +46,14 @@ if HAS_LLAMA:
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)
config = LlamaConfig(
num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16,
)
if hasattr(config, "pad_token_id"):
config.pad_token_id = config.eos_token_id
@@ -59,21 +62,27 @@ if HAS_LLAMA:
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaForSequenceClassification,
model_zoo.register(name='transformers_llama',
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_casual_lm',
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_sequence_classification',
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_llama",
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_llama_for_casual_lm",
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_llama_for_sequence_classification",
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
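
Note that the sequence-classification entry reuses the plain data_gen, i.e. no labels are passed, so loss_fn_for_seq_classification falls back to output.logits.mean() purely to obtain a differentiable scalar for the zoo's gradient checks; it is not a meaningful training loss. A rough sketch under the same import guard, with an arbitrary toy batch:

import torch

try:
    from transformers import LlamaConfig, LlamaForSequenceClassification

    config = LlamaConfig(
        num_hidden_layers=4,
        hidden_size=128,
        intermediate_size=256,
        num_attention_heads=4,
        max_position_embeddings=128,
        num_labels=16,
    )
    # Needed so the model can locate the last non-pad token of each sequence.
    config.pad_token_id = config.eos_token_id

    model = LlamaForSequenceClassification(config)
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    logits = model(input_ids=input_ids).logits  # shape (batch, num_labels) == (1, 16)
    logits.mean().backward()  # pseudo-loss used only for gradient checking
except ImportError:
    pass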

View File

@@ -20,8 +20,8 @@ def data_gen_for_causal_lm():
# LM data gen
# the `labels` for LM are the output tokens; since there is no padding, use `input_ids` as `labels`
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
labels = data["input_ids"].clone()
data["labels"] = labels
return data
@@ -29,8 +29,8 @@ def data_gen_for_sequence_classification():
# sequence classification data gen
# `labels` is the class index for the whole sequence, 0 or 1
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = torch.tensor([1])
data["input_ids"].clone()
data["labels"] = torch.tensor([1])
return data
@@ -38,14 +38,15 @@ def data_gen_for_question_answering():
# question answering data gen
# no `labels`; use `start_positions` and `end_positions` instead
data = data_gen()
data['start_positions'] = torch.tensor([0])
data['end_positions'] = torch.tensor([1])
data["start_positions"] = torch.tensor([0])
data["end_positions"] = torch.tensor([1])
return data
output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_lm = lambda x: x.loss
config = transformers.OPTConfig(
hidden_size=128,
@@ -57,24 +58,30 @@ config = transformers.OPTConfig(
# register the following models
# transformers.OPTModel,
# transformers.OPTForCausalLM,
model_zoo.register(name='transformers_opt',
model_fn=lambda: transformers.OPTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_opt_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_causal_lm',
model_fn=lambda: transformers.OPTForCausalLM(config),
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_question_answering',
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_opt",
model_fn=lambda: transformers.OPTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_opt_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_opt_for_causal_lm",
model_fn=lambda: transformers.OPTForCausalLM(config),
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_opt_for_question_answering",
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
# TODO: the loss and gradient check in the test are failing; to be fixed.
# model_zoo.register(name='transformers_opt_for_sequence_classification',

View File

@@ -28,10 +28,12 @@ def data_gen():
original_sizes = torch.tensor([[1764, 2646]], dtype=torch.int64)
reshaped_input_sizes = torch.tensor([[683, 1024]], dtype=torch.int64)
input_points = torch.tensor([[[[174.1497, 232.3129]]]], dtype=torch.float64)
return dict(pixel_values=pixel_values,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
input_points=input_points)
return dict(
pixel_values=pixel_values,
original_sizes=original_sizes,
reshaped_input_sizes=reshaped_input_sizes,
input_points=input_points,
)
# define output transform function
@@ -44,9 +46,11 @@ config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2
# register the SAM variants
model_zoo.register(name='transformers_sam',
model_fn=lambda: transformers.SamModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_sam",
model_fn=lambda: transformers.SamModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)

View File

@@ -27,7 +27,7 @@ def data_gen_for_conditional_generation():
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1, 644, 4598, 229, 19250, 5, 1, 229, 19250, 5, 1]]).long()
data['labels'] = labels
data["labels"] = labels
return data
@@ -36,7 +36,7 @@ def data_gen_for_t5_model():
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 5, 19, 1627, 5, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
data["decoder_input_ids"] = decoder_input_ids
return data
@@ -55,21 +55,27 @@ config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decode
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_t5",
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_t5_for_conditional_generation",
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_t5_encoder_model",
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True),
)
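
The split between data_gen_for_t5_model and data_gen_for_conditional_generation reflects an API asymmetry: the bare T5Model must be handed decoder_input_ids explicitly, while T5ForConditionalGeneration right-shifts labels into decoder inputs on its own. A small sketch of both paths (the config below mirrors only the visible part of the one above, whose argument list is truncated in this hunk):

import torch
import transformers

config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0)
input_ids = torch.randint(0, config.vocab_size, (1, 8))
labels = torch.randint(0, config.vocab_size, (1, 8))

# Bare model: the caller supplies decoder inputs, e.g. via the model's own shift helper.
t5 = transformers.T5Model(config)
decoder_input_ids = t5._shift_right(labels)
hidden = t5(input_ids=input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state

# Conditional-generation head: pass labels and the right shift happens internally.
t5_lm = transformers.T5ForConditionalGeneration(config)
loss = t5_lm(input_ids=input_ids, labels=labels).loss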

View File

@@ -18,15 +18,15 @@ def data_gen():
def data_gen_for_image_classification():
data = data_gen()
data['labels'] = torch.tensor([0])
data["labels"] = torch.tensor([0])
return data
def data_gen_for_masked_image_modeling():
data = data_gen()
num_patches = (config.image_size // config.patch_size)**2
num_patches = (config.image_size // config.patch_size) ** 2
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
data['bool_masked_pos'] = bool_masked_pos
data["bool_masked_pos"] = bool_masked_pos
return data
@@ -42,23 +42,29 @@ loss_fn_for_masked_image_modeling = lambda x: x.loss
# transformers.ViTModel,
# transformers.ViTForMaskedImageModeling,
# transformers.ViTForImageClassification,
model_zoo.register(name='transformers_vit',
model_fn=lambda: transformers.ViTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_vit_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_vit",
model_fn=lambda: transformers.ViTModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_vit_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name='transformers_vit_for_masked_image_modeling',
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
data_gen_fn=data_gen_for_masked_image_modeling,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_masked_image_modeling,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_vit_for_masked_image_modeling",
model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
data_gen_fn=data_gen_for_masked_image_modeling,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_masked_image_modeling,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name='transformers_vit_for_image_classification',
model_fn=lambda: transformers.ViTForImageClassification(config),
data_gen_fn=data_gen_for_image_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_image_classification,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_vit_for_image_classification",
model_fn=lambda: transformers.ViTForImageClassification(config),
data_gen_fn=data_gen_for_image_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_image_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
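
For the masked-image-modeling entry, bool_masked_pos needs exactly one flag per patch, which is where the (config.image_size // config.patch_size) ** 2 term comes from. A brief sketch with a hypothetical tiny config (32x32 images, 4x4 patches, hence 64 patches); encoder_stride is pinned to the patch size here so the reconstruction matches the input resolution, which the zoo's real config may handle differently:

import torch
import transformers

# Hypothetical small config; the real one lives above this hunk.
config = transformers.ViTConfig(
    image_size=32,
    patch_size=4,
    encoder_stride=4,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
)
model = transformers.ViTForMaskedImageModeling(config)

pixel_values = torch.rand(1, config.num_channels, config.image_size, config.image_size)
num_patches = (config.image_size // config.patch_size) ** 2  # 8 * 8 = 64
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

# The reconstruction loss is only computed where bool_masked_pos is True.
outputs = model(pixel_values=pixel_values, bool_masked_pos=bool_masked_pos)
print(outputs.loss)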

View File

@@ -33,7 +33,7 @@ def data_gen_for_conditional_generation():
# or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
# only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
data = data_gen()
data['labels'] = torch.tensor([[0, 1]], dtype=torch.int64)
data["labels"] = torch.tensor([[0, 1]], dtype=torch.int64)
return data
@@ -44,8 +44,8 @@ def data_gen_for_audio_classification():
# `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
# `WhisperForAudioClassification` does not need `decoder_input_ids`
data = data_gen()
data.pop('decoder_input_ids')
data['labels'] = torch.tensor([1], dtype=torch.int64)
data.pop("decoder_input_ids")
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
@@ -69,23 +69,29 @@ config = transformers.WhisperConfig(
)
# register the Whisper variants
model_zoo.register(name='transformers_whisper',
model_fn=lambda: transformers.WhisperModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_whisper",
model_fn=lambda: transformers.WhisperModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name='transformers_whisper_for_conditional_generation',
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_whisper_for_conditional_generation",
model_fn=lambda: transformers.WhisperForConditionalGeneration(config),
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(name='transformers_whisper_for_audio_classification',
model_fn=lambda: transformers.WhisperForAudioClassification(config),
data_gen_fn=data_gen_for_audio_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(
name="transformers_whisper_for_audio_classification",
model_fn=lambda: transformers.WhisperForAudioClassification(config),
data_gen_fn=data_gen_for_audio_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_attr,
model_attribute=ModelAttribute(has_control_flow=True),
)