[shardformer] upgrade transformers to 4.39.3 (#5815)

* [shardformer]upgrade transformers for gpt2/gptj/whisper (#5807)

* [shardformer] fix modeling of gpt2 and gptj

* [shardformer] fix whisper modeling

* [misc] update requirements

---------

Co-authored-by: ver217 <lhx0217@gmail.com>

* [shardformer]upgrade transformers for mistral (#5808)

* upgrade transformers for mistral

* fix

* fix

* [shardformer]upgrade transformers for llama (#5809)

* update transformers

fix

* fix

* fix

* [inference] upgrade transformers (#5810)

* update transformers

fix

* fix

* fix

* fix

* fix

* [gemini] update transformers for gemini (#5814)

---------

Co-authored-by: ver217 <lhx0217@gmail.com>
This commit is contained in:
flybird11111
2024-06-14 10:59:33 +08:00
committed by GitHub
parent 3bcbba9262
commit 2ddf624a86
12 changed files with 257 additions and 240 deletions

View File

@@ -28,15 +28,22 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, K_H, D, dtype):
torch.manual_seed(10)
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D, dtype=dtype)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, : D // 2]
sin_2 = sin[:, : D // 2]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data

View File

@@ -43,15 +43,19 @@ def torch_rotary_emb(x, cos, sin):
def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype, use_new_kcache_layout):
TOTAL_TOKENS = BATCH_SIZE * SEQ_LEN
# our crafted op equals to Transformers
x0 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x1 = torch.randn(TOTAL_TOKENS, SEQ_LEN, D)
x0 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
x1 = torch.randn(BATCH_SIZE, H, SEQ_LEN, D, dtype=dtype)
emb = LlamaRotaryEmbedding(D)
cos, sin = emb(x0, TOTAL_TOKENS)
position_ids = torch.arange(TOTAL_TOKENS).reshape((BATCH_SIZE, SEQ_LEN))
cos, sin = emb(x0, position_ids)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin)
cos = cos.reshape((TOTAL_TOKENS, -1))
sin = sin.reshape((TOTAL_TOKENS, -1))
cos_2 = cos[:, :32]
sin_2 = sin[:, :32]
position_ids = torch.arange(TOTAL_TOKENS)
embd_x0, _ = apply_rotary_pos_emb(x0, x1, cos, sin, position_ids)
embd_stimulated_x = torch_rotary_emb(x0, cos_2, sin_2)
x2 = x0.transpose(1, 2).reshape(TOTAL_TOKENS, H, D)
embd_stimulated_x = torch_rotary_emb(x2, cos_2, sin_2)
embd_stimulated_x = embd_stimulated_x.reshape((BATCH_SIZE, SEQ_LEN, H, D)).transpose(1, 2)
assert torch.allclose(embd_x0, embd_stimulated_x)
# create data