[test] Hotfix/fix some model test and refactor check util api (#4369)

* fix llama test

* fix test bug of bert, blip2, bloom, gpt2

* fix llama test

* fix opt test

* fix sam test

* fix sam test

* fix t5 test

* fix vit test

* fix whisper test

* fix whisper test

* polish code

* adjust allclose parameter

* Add mistakenly deleted code

* addjust allclose

* change loss function for some base model
This commit is contained in:
Bin Jia
2023-08-03 14:51:36 +08:00
committed by Hongxin Liu
parent c3ca53cf05
commit 5c6f183192
16 changed files with 135 additions and 336 deletions

View File

@@ -44,7 +44,8 @@ def data_gen_for_question_answering():
output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_lm = lambda x: x.loss
config = transformers.OPTConfig(
hidden_size=128,