diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py
index 3a0cf584..a38ad5b0 100644
--- a/gpt4all/models/modeling_gpt_jr.py
+++ b/gpt4all/models/modeling_gpt_jr.py
@@ -296,7 +296,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
         tensor = tensor.view(new_shape)
 
-        return tensor.permute(0, 2, 1)
+        return tensor.permute(0, 1, 3, 2)
 
 
     def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
@@ -304,6 +304,7 @@ class GPTJRCrossAttention(GPTJRAttention):
         Merges attn_head_size dim and num_attn_heads dim into hidden dim
         """
         # tensor -> (bs, seq_len, num_attention_heads, head_dim)
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
         new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
         return tensor.view(new_shape)
 
@@ -316,10 +317,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):
 
-        # query -> (bs, seq_len, num_attention_heads, head_dim)
-        # key -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
-        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
@@ -329,10 +330,10 @@ class GPTJRCrossAttention(GPTJRAttention):
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
-        # value -> (bs, num_attention_heads, head_dim)
-        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+        # value -> (bs, num_attention_heads, neighbors, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
         # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
-        attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))
+        attn_output = torch.matmul(attn_weights, value)
 
         return attn_output, attn_weights
@@ -362,8 +363,8 @@ class GPTJRCrossAttention(GPTJRAttention):
         value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
 
-        key = key.permute(0, 2, 1)
-        query = query.permute(0, 2, 1, 3)
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)
 
         if layer_past is not None:
             past_key = layer_past[0]
@@ -454,30 +455,25 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual
 
         # encoder_hidden_states -> (bs, knn, encoder_dim)
+        # may not need, can norm encodings
         encoder_normed = self.ln_2(encoder_hidden_states)
 
-        num_neighbors = encoder_normed.shape[1]
-        cross_attn_outputs = []
-        for k in range(num_neighbors):
-            # cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
-            cross_attn_output = self.cross_attn(
-                residual,
-                encoder_hidden_states=encoder_normed[:, k, :],
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
-            cross_attn_outputs.append(cross_attn_output[0])
+        # cross_attn_outputs -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
 
-        cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
         # gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
         cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output
+            cross_attn_output[0]
         )
 
         alpha = self.alpha if self.training else 0.5
-
         hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
 
         if use_cache:
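
For reference, a minimal standalone sketch of the new shape flow through _split_knn_attn_heads, the forward permutes, _attn, and _merge_heads. The sizes (bs, seq_len, neighbors, num_heads, head_dim) are illustrative assumptions, not values from the patch; the sketch only checks that the permute/matmul chain produces the shapes annotated in the comments.

import torch

# Illustrative sizes -- assumptions for the sketch, not from the patch.
bs, seq_len, neighbors = 2, 8, 4
num_heads, head_dim = 16, 64
hidden = num_heads * head_dim

# query: decoder side, annotated in _attn as (bs, num_heads, seq_len, head_dim)
query = torch.randn(bs, num_heads, seq_len, head_dim)

# key/value: retrieved neighbors, (bs, neighbors, hidden).
# Mirroring _split_knn_attn_heads: view to (bs, neighbors, num_heads, head_dim),
# then permute(0, 1, 3, 2) -> (bs, neighbors, head_dim, num_heads).
kv = torch.randn(bs, neighbors, hidden)
key = kv.view(bs, neighbors, num_heads, head_dim).permute(0, 1, 3, 2)
value = kv.view(bs, neighbors, num_heads, head_dim).permute(0, 1, 3, 2)

# Forward permutes from the patch.
key = key.permute(0, 3, 2, 1)      # (bs, num_heads, head_dim, neighbors)
value = value.permute(0, 3, 1, 2)  # (bs, num_heads, neighbors, head_dim)

# _attn: key already exposes (head_dim, neighbors), so no transpose is needed.
attn_weights = torch.matmul(query, key)          # (bs, num_heads, seq_len, neighbors)
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.matmul(attn_weights, value)  # (bs, num_heads, seq_len, head_dim)

# _merge_heads: move heads back next to head_dim, then flatten to hidden.
merged = attn_output.permute(0, 2, 1, 3).contiguous().view(bs, seq_len, hidden)

assert attn_weights.shape == (bs, num_heads, seq_len, neighbors)
assert merged.shape == (bs, seq_len, hidden)

Because _split_knn_attn_heads already places head_dim ahead of the head axis, the forward pass can lay key out as (bs, num_heads, head_dim, neighbors) directly, which is why _attn drops the key.transpose(-1, -2). At the block level, the per-neighbor Python loop plus stack/mean is replaced by one batched cross-attention over all neighbors, and its output is blended with the self-attention residual as (1 - alpha) * cross_attn_ff + alpha * self_attention_residual.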