From ca66d12d89042a6ca0a91caccd82f365eb522306 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Fri, 21 Apr 2023 14:23:33 +0000
Subject: [PATCH] fix: remove causal cross attn mask

---
 gpt4all/models/modeling_gpt_jr.py | 25 -------------------------
 1 file changed, 25 deletions(-)

diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py
index d2e57857..3a0cf584 100644
--- a/gpt4all/models/modeling_gpt_jr.py
+++ b/gpt4all/models/modeling_gpt_jr.py
@@ -316,34 +316,11 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):
 
-        # compute causal mask from causal mask buffer
-        # since key and value don't have seq length, just use causal mask as normal
-        query_length = query.size(-2)
-        causal_mask = self.bias[:, :, : query_length, :query_length].to(torch.bool)
-
-        # Keep the attention weights computation in fp32 to avoid overflow issues
-        # TODO: do we need to do this with bfloat16??
-        # query = query.to(torch.float32)
-        # key = key.to(torch.float32)
-
         # query -> (bs, seq_len, num_attention_heads, head_dim)
         # key -> (bs, num_attention_heads, head_dim)
         # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        mask_value = torch.finfo(attn_weights.dtype).min
-        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        attn_weights = attn_weights / self.scale_attn
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            # attn mask (1, 1, 1, seq_len)
-            attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2)
-
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
@@ -441,8 +418,6 @@ class GPTJRBlock(nn.Module):
         self.attn = GPTJRAttention(config)
         self.mlp = GPTJRMLP(inner_dim, config)
 
-        # TODO: fix for n neighbors
-        # for SBERT this is 384
         self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
         self.cross_attn = GPTJRCrossAttention(config)
         self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
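
For context, a minimal standalone sketch of the cross-attention weight computation the first hunk leaves in place, assuming the tensor shapes given in the retained comments. The function name, the `dropout_p` argument, and the `unsqueeze` of the key (needed so the batched matmul broadcasts over `seq_len` in an isolated script) are illustrative and not part of the patch.

    import torch
    import torch.nn.functional as F

    def cross_attn_weights(query, key, value, dropout_p=0.0):
        # Shapes per the comments kept in the diff:
        #   query: (bs, seq_len, num_attention_heads, head_dim)
        #   key:   (bs, num_attention_heads, head_dim)  -- pooled encoder output, no sequence axis
        # With no key sequence dimension there is nothing for a causal mask to hide,
        # so the weights go straight from the matmul into the softmax.
        key = key.unsqueeze(1)                                      # (bs, 1, heads, head_dim), added here for broadcasting
        attn_weights = torch.matmul(query, key.transpose(-1, -2))   # (bs, seq_len, heads, heads)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = F.dropout(attn_weights, p=dropout_p)         # stands in for self.attn_dropout
        return attn_weights

    # Example shapes: bs=2, seq_len=16, heads=4, head_dim=8
    q = torch.randn(2, 16, 4, 8)
    k = torch.randn(2, 4, 8)
    v = torch.randn(2, 4, 8)
    print(cross_attn_weights(q, k, v).shape)  # torch.Size([2, 16, 4, 4])

Note that the hunk also drops the division by self.scale_attn and the additive attention_mask before the softmax, so neither appears in the sketch above.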