[pipeline] All bert models (#4233)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom

* revert the bloom changes

* remove unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision

* add pure pipeline test

* finish some bert models

* finish all bert models

* finish bert tests

* fix bugs

* fix bugs

* fix test pipeline

* fix data gen for qa

* update the set_pipeline_forward method (sketched below)

* shared params

* fix bugs
Author: Jianghai
Date: 2023-07-17 16:12:20 +08:00
Committed by: Hongxin Liu
Parent: a14d352088
Commit: e7cc62d735
13 changed files with 988 additions and 144 deletions
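Several of the commits above ("bind argument to policy", "update the set_pipeline_forward method", "shared params") revolve around installing a stage-aware forward on each model class. The snippet below is a minimal, generic sketch of that idea only, not ColossalAI's actual API; PipelineStageManager, pipeline_forward and TinyModel are illustrative names introduced here.

# Generic sketch: bind a pipeline stage manager into a replacement forward
# with functools.partial, then attach it to the model instance. Illustrative
# names only; this is not ColossalAI's real policy interface.
from functools import partial
import types

import torch
import torch.nn as nn


class PipelineStageManager:
    """Hypothetical holder of pipeline-stage information."""

    def __init__(self, stage: int, num_stages: int):
        self.stage = stage
        self.num_stages = num_stages

    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 8)

    def forward(self, hidden_states):
        return self.layer(hidden_states)


def pipeline_forward(self, hidden_states, stage_manager: PipelineStageManager):
    # Run this stage's layers; only the last stage returns final outputs,
    # earlier stages hand intermediate activations to the next stage.
    hidden_states = self.layer(hidden_states)
    if stage_manager.is_last_stage():
        return {"logits": hidden_states}
    return {"hidden_states": hidden_states}


# "Bind argument to policy": fix the stage manager as an argument, then
# install the bound function as this model instance's forward.
model = TinyModel()
stage_manager = PipelineStageManager(stage=1, num_stages=2)
model.forward = types.MethodType(partial(pipeline_forward, stage_manager=stage_manager), model)

out = model(torch.randn(2, 8))  # last stage here, so out holds 'logits'

Binding the stage manager with functools.partial is what lets one forward implementation be reused across stages, which is the pattern the commit messages describe for the BERT family.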


@@ -1 +1 @@
-from .torchrec import *
+#from .torchrec import *


@@ -87,6 +87,17 @@ def data_gen_for_mcq():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


+def data_gen_for_qa():
+    # generating data for question answering
+    # no need for labels and use start and end position instead
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x

@@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_question_answering',
+                   model_fn=lambda: transformers.BertForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_qa,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
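For context, the registration above leans on the fact that BertForQuestionAnswering computes its loss from start_positions and end_positions rather than a labels field. The standalone sketch below checks that layout; the zoo's data_gen() and shared config are not part of this diff, so the small config and dummy batch here are assumptions for illustration.

# Standalone check of the QA data layout used by data_gen_for_qa().
# The tiny config and hand-made batch are assumptions, not the zoo's values.
import torch
import transformers

config = transformers.BertConfig(hidden_size=32,
                                 num_hidden_layers=2,
                                 num_attention_heads=2,
                                 intermediate_size=64)
model = transformers.BertForQuestionAnswering(config)

data = dict(input_ids=torch.tensor([[101, 7592, 1010, 2088, 102]], dtype=torch.int64),
            attention_mask=torch.ones(1, 5, dtype=torch.int64),
            token_type_ids=torch.zeros(1, 5, dtype=torch.int64))
data['start_positions'] = torch.tensor([0], dtype=torch.int64)
data['end_positions'] = torch.tensor([1], dtype=torch.int64)

outputs = model(**data)
# With start/end positions supplied, the model returns a loss directly,
# which is why data_gen_for_qa() needs no separate labels entry.
print(outputs.loss, outputs.start_logits.shape, outputs.end_logits.shape)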