[pipeline] rewrite bert tests and fix some bugs (#4409)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name confilt

* add bloom model and policy ,revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the imediate output to default dict

* change the default output of get_shared_params

* rewrite bert test

* rewrite bert test

* fix some bugs

* del pipeline tests

* del pipeline tests

* del useless print

* del useless print

* rewrite data repeats
This commit is contained in:
Jianghai
2023-08-11 10:32:53 +08:00
committed by Hongxin Liu
parent d2cd48e0be
commit 7596e9ae08
4 changed files with 83 additions and 154 deletions

View File

@@ -104,7 +104,8 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x
# define loss funciton
loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
))
loss_fn = lambda x: x.loss
config = transformers.BertConfig(hidden_size=128,