[gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)

* fix

* test ci

* fix ci
This commit is contained in:
flybird11111
2023-12-08 11:10:51 +08:00
committed by GitHub
parent b397104438
commit 21aa5de00b
3 changed files with 59 additions and 2 deletions


@@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss
 config = transformers.GPTJConfig(
     n_layer=2,
-    n_head=16,
+    n_head=4,
     vocab_size=50258,
     attn_pdrop=0,
     embd_pdrop=0,
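
The change above shrinks `n_head` in the GPT-J test config from 16 to 4. A likely reason (an assumption, since the diff itself does not state one) is that tensor parallelism splits attention heads evenly across ranks, so `n_head` must be divisible by the tensor-parallel degree and small enough for the CI setup. A minimal sketch of that constraint, with a hypothetical `heads_per_rank` helper and an assumed `tp_size`:

```python
def heads_per_rank(n_head: int, tp_size: int) -> int:
    """Number of attention heads each tensor-parallel rank owns.

    Megatron-style tensor parallelism shards the attention heads across
    ranks, so n_head must divide evenly by tp_size or the per-rank
    projection shapes will not match.
    """
    if n_head % tp_size != 0:
        raise ValueError(
            f"n_head={n_head} is not divisible by tp_size={tp_size}"
        )
    return n_head // tp_size

# With the updated test config (n_head=4) and an assumed tp_size of 4,
# each rank owns exactly one head; n_head=16 would also divide evenly,
# but yields a larger model than the CI test needs.
print(heads_per_rank(4, 4))
print(heads_per_rank(16, 4))
```

This is only an illustration of the sharding arithmetic, not the actual Gemini/tensor-parallel implementation touched by the hotfix.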