Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-01 01:06:00 +00:00
[pipeline] All bert models (#4233)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be reverted and copied to feature/bloom.
* revert the bloom changes
* remove unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
* add pure pipeline test
* finish some bert models
* finish all bert models
* finish bert tests
* fix bugs
* fix bugs
* fix test pipeline
* fix data gen for qa
* update set_pipeline_forward
* shared params
* fix bugs
@@ -1 +1 @@
-from .torchrec import *
+#from .torchrec import *
@@ -87,6 +87,17 @@ def data_gen_for_mcq():
     return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


+def data_gen_for_qa():
+    # generating data for question answering
+    # no need for labels; use start and end positions instead
+    data = data_gen()
+    start_positions = torch.tensor([0], dtype=torch.int64)
+    data['start_positions'] = start_positions
+    end_positions = torch.tensor([1], dtype=torch.int64)
+    data['end_positions'] = end_positions
+    return data
+
+
 # define output transform function
 output_transform_fn = lambda x: x
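For context, the generated QA sample can be pushed through a tiny BertForQuestionAnswering as a smoke test. Below is a minimal, self-contained sketch; the inline data_gen and the small BertConfig are stand-ins for the zoo's actual helpers, which this page does not show in full:

    import torch
    import transformers

    # stand-in for the zoo's data_gen(): one pre-tokenized sentence (hypothetical token ids)
    def data_gen():
        input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
        attention_mask = torch.ones_like(input_ids)
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # a deliberately tiny config so the smoke test runs quickly (an assumption, not the zoo's config)
    config = transformers.BertConfig(hidden_size=32,
                                     num_hidden_layers=2,
                                     num_attention_heads=4,
                                     intermediate_size=64)
    model = transformers.BertForQuestionAnswering(config)

    data = data_gen()
    data['start_positions'] = torch.tensor([0], dtype=torch.int64)
    data['end_positions'] = torch.tensor([1], dtype=torch.int64)

    outputs = model(**data)
    print(outputs.loss)  # cross-entropy over start/end logits; no 'labels' key needed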
@@ -150,3 +161,9 @@ model_zoo.register(name='transformers_bert_for_mcq',
                    output_transform_fn=output_transform_fn,
                    loss_fn=loss_fn,
                    model_attribute=ModelAttribute(has_control_flow=True))
+model_zoo.register(name='transformers_bert_for_question_answering',
+                   model_fn=lambda: transformers.BertForQuestionAnswering(config),
+                   data_gen_fn=data_gen_for_qa,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn,
+                   model_attribute=ModelAttribute(has_control_flow=True))
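Downstream, the new entry is consumed like any other zoo entry. A minimal sketch of the lookup-and-run loop, assuming the test kit's convention that get_sub_registry yields (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) tuples:

    from tests.kit.model_zoo import model_zoo

    sub_registry = model_zoo.get_sub_registry('transformers_bert_for_question_answering')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_registry.items():
        model = model_fn()                            # fresh BertForQuestionAnswering
        data = data_gen_fn()                          # dict carrying start/end positions
        outputs = output_transform_fn(model(**data))  # identity transform in this case
        loss = loss_fn(outputs)                       # expected to pick outputs.loss
        loss.backward()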