From 1559c0df410ebea6b0c64cee2362c8da7e93e8f4 Mon Sep 17 00:00:00 2001 From: ver217 Date: Mon, 21 Mar 2022 12:01:31 +0800 Subject: [PATCH] fix attn mask shape of gpt (#472) --- model_zoo/gpt/gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py index dadbc152b..d7b5750fb 100644 --- a/model_zoo/gpt/gpt.py +++ b/model_zoo/gpt/gpt.py @@ -292,7 +292,7 @@ class GPT(nn.Module): # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] # Adapted from huggingface if attention_mask is not None: - batch_size = x.shape[0] + batch_size = input_ids.shape[0] attention_mask = attention_mask.view(batch_size, -1) attention_mask = col_nn.partition_batch(attention_mask) attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)