From 288304028645f545b1eb0a6ffda46143ec92c422 Mon Sep 17 00:00:00 2001
From: LuGY <74758262+Gy-Lu@users.noreply.github.com>
Date: Tue, 26 Apr 2022 13:33:27 +0800
Subject: [PATCH] [example] change qkv processing (#870)

---
 model_zoo/gpt/gpt.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/model_zoo/gpt/gpt.py b/model_zoo/gpt/gpt.py
index 7384cc3b4..f684316b6 100644
--- a/model_zoo/gpt/gpt.py
+++ b/model_zoo/gpt/gpt.py
@@ -89,13 +89,14 @@ class GPTSelfAttention(nn.Module):
 
     def forward(self, x, attention_mask=None):
         qkv = self.query_key_value(x)
-        all_head_size = qkv.shape[-1] // 3
-        num_attention_heads = divide(all_head_size, self.attention_head_size)
-        new_qkv_shape = qkv.shape[:-1] + \
-            (num_attention_heads, 3 * self.attention_head_size)
-        qkv = qkv.view(new_qkv_shape)
-        qkv = qkv.permute((0, 2, 1, 3))
         q, k, v = torch.chunk(qkv, 3, dim=-1)
+        all_head_size = q.shape[-1]
+        num_attention_heads = divide(all_head_size, self.attention_head_size)
+        new_shape = q.shape[:-1] + \
+            (num_attention_heads, self.attention_head_size)
+        q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
+        v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()
 
         x = torch.matmul(q, k.transpose(-1, -2))
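
Note (not part of the patch): the change splits the fused qkv projection into q, k, v before the per-head reshape, instead of reshaping and permuting the fused tensor first and chunking last. The sketch below replays the post-patch path on dummy tensors to show the resulting shapes; the sizes (batch, seq_len, num_heads, head_size) and the standalone nn.Linear are illustrative assumptions, and the patch's divide() helper is skipped since num_heads is fixed here.

    # Minimal sketch of the new qkv processing, under assumed sizes.
    import torch
    import torch.nn as nn

    batch, seq_len, num_heads, head_size = 2, 8, 4, 16
    hidden = num_heads * head_size

    # Fused projection producing q, k, v concatenated along the feature dim.
    query_key_value = nn.Linear(hidden, 3 * hidden)
    x = torch.randn(batch, seq_len, hidden)

    qkv = query_key_value(x)               # (batch, seq_len, 3 * hidden)
    q, k, v = torch.chunk(qkv, 3, dim=-1)  # each (batch, seq_len, hidden)

    # Reshape each tensor separately to (batch, num_heads, seq_len, head_size).
    new_shape = q.shape[:-1] + (num_heads, head_size)
    q = q.view(new_shape).permute((0, 2, 1, 3)).contiguous()
    k = k.view(new_shape).permute((0, 2, 1, 3)).contiguous()
    v = v.view(new_shape).permute((0, 2, 1, 3)).contiguous()

    scores = torch.matmul(q, k.transpose(-1, -2))
    print(scores.shape)  # torch.Size([2, 4, 8, 8])

Chunking first makes the per-head reshape independent of the fused tensor, and matches a projection whose output is laid out as [q | k | v] along the feature dim; the pre-patch code instead viewed the fused tensor as (num_heads, 3 * head_size) per position, which presumably assumed q/k/v interleaved per head.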