mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 18:40:28 +00:00
[pipeline] rewrite bert tests and fix some bugs (#4409)
* add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretraining * add bert_for_pretraining forward and policy * fix typos * cancel warning * change the imediate output to default dict * change the default output of get_shared_params * rewrite bert test * rewrite bert test * fix some bugs * del pipeline tests * del pipeline tests * del useless print * del useless print * rewrite data repeats
This commit is contained in:
@@ -104,7 +104,8 @@ def data_gen_for_qa():
|
||||
output_transform_fn = lambda x: x
|
||||
|
||||
# define loss funciton
|
||||
loss_fn_for_bert_model = lambda x: x.pooler_output.sum()
|
||||
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
|
||||
))
|
||||
loss_fn = lambda x: x.loss
|
||||
|
||||
config = transformers.BertConfig(hidden_size=128,
|
||||
|
Reference in New Issue
Block a user