mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 13:30:19 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -19,44 +19,52 @@ def data_gen_fn():
|
||||
|
||||
def data_gen_for_pretrain():
|
||||
inputs = data_gen_fn()
|
||||
inputs['labels'] = inputs['input_ids'].clone()
|
||||
inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
|
||||
inputs["labels"] = inputs["input_ids"].clone()
|
||||
inputs["sentence_order_label"] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
|
||||
return inputs
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
config = transformers.AlbertConfig(embedding_size=128,
|
||||
hidden_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
intermediate_size=256)
|
||||
config = transformers.AlbertConfig(
|
||||
embedding_size=128, hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256
|
||||
)
|
||||
|
||||
model_zoo.register(name='transformers_albert',
|
||||
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_pretraining',
|
||||
model_fn=lambda: transformers.AlbertForPreTraining(config),
|
||||
data_gen_fn=data_gen_for_pretrain,
|
||||
output_transform_fn=lambda x: dict(loss=x.loss),
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_masked_lm',
|
||||
model_fn=lambda: transformers.AlbertForMaskedLM(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_sequence_classification',
|
||||
model_fn=lambda: transformers.AlbertForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_token_classification',
|
||||
model_fn=lambda: transformers.AlbertForTokenClassification(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_albert",
|
||||
model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_pretraining",
|
||||
model_fn=lambda: transformers.AlbertForPreTraining(config),
|
||||
data_gen_fn=data_gen_for_pretrain,
|
||||
output_transform_fn=lambda x: dict(loss=x.loss),
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_masked_lm",
|
||||
model_fn=lambda: transformers.AlbertForMaskedLM(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_sequence_classification",
|
||||
model_fn=lambda: transformers.AlbertForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_token_classification",
|
||||
model_fn=lambda: transformers.AlbertForTokenClassification(config),
|
||||
data_gen_fn=data_gen_fn,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
# ===============================
|
||||
# Register multi-sentence ALBERT
|
||||
@@ -80,13 +88,17 @@ def data_gen_for_mcq():
|
||||
return encoding
|
||||
|
||||
|
||||
model_zoo.register(name='transformers_albert_for_question_answering',
|
||||
model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_qa,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_albert_for_multiple_choice',
|
||||
model_fn=lambda: transformers.AlbertForMultipleChoice(config),
|
||||
data_gen_fn=data_gen_for_mcq,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_question_answering",
|
||||
model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_qa,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
model_zoo.register(
|
||||
name="transformers_albert_for_multiple_choice",
|
||||
model_fn=lambda: transformers.AlbertForMultipleChoice(config),
|
||||
data_gen_fn=data_gen_for_mcq,
|
||||
output_transform_fn=output_transform_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
Reference in New Issue
Block a user