Merge branch 'main' into feature/shardformer

Hongxin Liu (committed by GitHub)
2023-09-04 23:43:13 +08:00
138 changed files with 4664 additions and 4219 deletions


@@ -17,6 +17,13 @@ def data_gen_fn():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
 
 
+def data_gen_for_pretrain():
+    inputs = data_gen_fn()
+    inputs['labels'] = inputs['input_ids'].clone()
+    inputs['sentence_order_label'] = torch.zeros(BATCH_SIZE, dtype=torch.int64)
+    return inputs
+
+
 output_transform_fn = lambda x: x
 
 config = transformers.AlbertConfig(embedding_size=128,
@@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
                                    intermediate_size=256)
 
 model_zoo.register(name='transformers_albert',
-                   model_fn=lambda: transformers.AlbertModel(config),
+                   model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
                    data_gen_fn=data_gen_fn,
                    output_transform_fn=output_transform_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_albert_for_pretraining',
                    model_fn=lambda: transformers.AlbertForPreTraining(config),
-                   data_gen_fn=data_gen_fn,
-                   output_transform_fn=output_transform_fn,
+                   data_gen_fn=data_gen_for_pretrain,
+                   output_transform_fn=lambda x: dict(loss=x.loss),
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_albert_for_masked_lm',
                    model_fn=lambda: transformers.AlbertForMaskedLM(config),

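The pretraining registration above works because AlbertForPreTraining only returns a loss when both `labels` (masked-LM targets) and `sentence_order_label` (sentence-order-prediction targets) are supplied, which is why data_gen_for_pretrain extends data_gen_fn with both tensors and the new output transform keeps just the loss. A minimal standalone sketch; BATCH_SIZE, SEQ_LEN, and the config fields not visible in the diff are assumed values for illustration:

import torch
import transformers

BATCH_SIZE = 2    # assumed; the zoo defines its own module-level BATCH_SIZE
SEQ_LEN = 8       # assumed, for illustration only
config = transformers.AlbertConfig(embedding_size=128, hidden_size=128,
                                   num_hidden_layers=2, num_attention_heads=4,
                                   intermediate_size=256)
model = transformers.AlbertForPreTraining(config)

input_ids = torch.randint(0, config.vocab_size, (BATCH_SIZE, SEQ_LEN))
outputs = model(input_ids=input_ids,
                labels=input_ids.clone(),   # masked-LM targets, as in data_gen_for_pretrain
                sentence_order_label=torch.zeros(BATCH_SIZE, dtype=torch.int64))
print(outputs.loss)   # a scalar; exactly what lambda x: dict(loss=x.loss) extracts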

@@ -113,6 +113,7 @@ def data_gen_for_qa():
 output_transform_fn = lambda x: x
 
+# define loss function
 loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state,
                                                                 torch.ones_like(x.last_hidden_state))
 loss_fn = lambda x: x.loss
@@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,
 # register the BERT variants
 model_zoo.register(name='transformers_bert',
-                   model_fn=lambda: transformers.BertModel(config),
+                   model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
                    data_gen_fn=data_gen,
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn_for_bert_model,

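The two BERT hunks fit together: with add_pooling_layer=False the model is built without its pooler, so the output's pooler_output is None and a test loss must be derived from last_hidden_state instead; the MSE against a tensor of ones gives every encoder layer a gradient. A minimal sketch, with config sizes assumed for illustration:

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2,
                                 num_attention_heads=4, intermediate_size=256)
model = transformers.BertModel(config, add_pooling_layer=False)

out = model(input_ids=torch.randint(0, config.vocab_size, (2, 8)))
assert out.pooler_output is None   # no pooler module was created
loss = torch.nn.functional.mse_loss(out.last_hidden_state,
                                    torch.ones_like(out.last_hidden_state))
loss.backward()   # gradients flow through every encoder layer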

@@ -57,6 +57,12 @@ def data_gen_for_sequence_classification():
     return data
 
 
+def data_gen_for_double_heads():
+    data = data_gen_for_lm()
+    data['mc_labels'] = torch.zeros(data['input_ids'].shape[0], dtype=torch.int64)
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x
@@ -94,8 +100,8 @@ model_zoo.register(name='transformers_gpt_lm',
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_double_heads',
                    model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
-                   data_gen_fn=data_gen_for_lm,
-                   output_transform_fn=output_transform_fn,
+                   data_gen_fn=data_gen_for_double_heads,
+                   output_transform_fn=lambda x: dict(loss=x.loss + x.mc_loss),
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_gpt_for_question_answering',
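
The double-heads registration works because GPT2DoubleHeadsModel reports two separate losses: `loss` from the language-modeling head (driven by `labels`, which data_gen_for_lm is assumed to provide) and `mc_loss` from the multiple-choice head (driven by the new `mc_labels`), so the output transform sums them into the single loss the harness consumes. A minimal sketch, with config sizes and a batch of one assumed for illustration:

import torch
import transformers

config = transformers.GPT2Config(n_embd=128, n_layer=2, n_head=4)  # assumed sizes
model = transformers.GPT2DoubleHeadsModel(config)

# A single 2D sequence: `labels` drives the LM head's `loss`, `mc_labels`
# drives the multiple-choice head's `mc_loss`. With one sequence per example
# the multiple-choice head degenerates to a single candidate (mc_loss is
# trivially 0) but still exercises the head's shapes, avoiding the usual
# 3D (batch, num_choices, seq_len) input layout.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
out = model(input_ids=input_ids,
            labels=input_ids.clone(),
            mc_labels=torch.zeros(input_ids.shape[0], dtype=torch.int64))
total = out.loss + out.mc_loss   # what the registered transform returns as 'loss'
total.backward()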