Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-07-18 17:32:00 +00:00
fix: remove causal cross attn mask
This commit is contained in:
parent e62baf87f8
commit ca66d12d89
@@ -316,34 +316,11 @@ class GPTJRCrossAttention(GPTJRAttention):
        head_mask=None,
    ):

        # compute causal mask from causal mask buffer
        # since key and value don't have seq length, just use causal mask as normal
        query_length = query.size(-2)
        causal_mask = self.bias[:, :, :query_length, :query_length].to(torch.bool)

        # Keep the attention weights computation in fp32 to avoid overflow issues
        # TODO: do we need to do this with bfloat16??
        # query = query.to(torch.float32)
        # key = key.to(torch.float32)

        # query -> (bs, seq_len, num_attention_heads, head_dim)
        # key -> (bs, num_attention_heads, head_dim)
        # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        attn_weights = attn_weights / self.scale_attn

        if attention_mask is not None:
            # Apply the attention mask
            # attn mask (1, 1, 1, seq_len)
            attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)
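For reference, a minimal, self-contained sketch of the cross-attention weight computation this hunk touches, with the causal mask left out as the commit intends. Tensor shapes follow the in-code comments; the concrete sizes, the scale factor, and the dropout probability are made-up stand-ins for `self.scale_attn` and `self.attn_dropout`, and the key is unsqueezed here purely so the matmul broadcasts over the query's sequence dimension in a standalone script.

import torch
import torch.nn as nn

# Shapes from the comments in the hunk:
#   query -> (bs, seq_len, num_attention_heads, head_dim)
#   key   -> (bs, num_attention_heads, head_dim)
bs, seq_len, num_heads, head_dim = 2, 5, 4, 16
query = torch.randn(bs, seq_len, num_heads, head_dim)
key = torch.randn(bs, num_heads, head_dim)

scale_attn = head_dim ** 0.5        # stand-in for self.scale_attn (assumed value)
attn_dropout = nn.Dropout(p=0.1)    # stand-in for self.attn_dropout (assumed p)

# Broadcast the single retrieved key over every query position; the unsqueeze
# is an assumption made so this sketch runs on its own.
# attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
attn_weights = torch.matmul(query, key.unsqueeze(1).transpose(-1, -2))
attn_weights = attn_weights / scale_attn

# No causal mask is applied before the softmax (the change in this commit).
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_dropout(attn_weights)
print(attn_weights.shape)  # torch.Size([2, 5, 4, 4])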
@@ -441,8 +418,6 @@ class GPTJRBlock(nn.Module):
        self.attn = GPTJRAttention(config)
        self.mlp = GPTJRMLP(inner_dim, config)

        # TODO: fix for n neighbors
        # for SBERT this is 384
        self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
        self.cross_attn = GPTJRCrossAttention(config)
        self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
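For context on the `ln_2` line above, a small sketch of normalizing retrieved encoder states before cross-attention. The value `encoder_dim = 384` follows the "for SBERT this is 384" comment in the hunk; the epsilon, batch size, and input tensor are hypothetical, chosen only to make the example runnable.

import torch
import torch.nn as nn

encoder_dim = 384            # per the "for SBERT this is 384" comment
layer_norm_epsilon = 1e-5    # assumed; the real value comes from config.layer_norm_epsilon

ln_2 = nn.LayerNorm(encoder_dim, eps=layer_norm_epsilon)

# One retrieved neighbour encoding per sequence in the batch (hypothetical shape).
bs = 2
retrieved = torch.randn(bs, encoder_dim)
normed = ln_2(retrieved)     # same shape, normalized over the last dimension
print(normed.shape)          # torch.Size([2, 384])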