mirror of https://github.com/nomic-ai/gpt4all.git
synced 2025-07-19 01:36:37 +00:00
fix: remove causal cross attn mask
This commit is contained in:
parent
e62baf87f8
commit
ca66d12d89
@@ -316,34 +316,11 @@ class GPTJRCrossAttention(GPTJRAttention):
         head_mask=None,
     ):
 
-        # compute causal mask from causal mask buffer
-        # since key and value don't have seq length, just use causal mask as normal
-        query_length = query.size(-2)
-        causal_mask = self.bias[:, :, : query_length, :query_length].to(torch.bool)
-
-        # Keep the attention weights computation in fp32 to avoid overflow issues
-        # TODO: do we need to do this with bfloat16??
-        # query = query.to(torch.float32)
-        # key = key.to(torch.float32)
-
         # query -> (bs, seq_len, num_attention_heads, head_dim)
         # key -> (bs, num_attention_heads, head_dim)
         # attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
-        mask_value = torch.finfo(attn_weights.dtype).min
-        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
-        attn_weights = torch.where(causal_mask, attn_weights, mask_value)
-
-        attn_weights = attn_weights / self.scale_attn
-
-        if attention_mask is not None:
-            # Apply the attention mask
-            # attn mask (1, 1, 1, seq_len)
-            attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2)
-
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_weights = attn_weights.to(value.dtype)
         attn_weights = self.attn_dropout(attn_weights)
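After this hunk the cross-attention weights reduce to the query-key dot product followed by softmax and dropout; there is no causal mask to build or fill, since (per the kept shape comments) the retrieved key carries no sequence length to mask over. Note the hunk also drops the scale_attn division and the additive attention_mask branch. Below is a minimal standalone sketch of the remaining path, not code from the repository: the function name, the explicit unsqueeze that makes the batched matmul broadcast outside the module, and functional dropout standing in for self.attn_dropout are assumptions for illustration.

import torch
import torch.nn.functional as F

def cross_attn_weights(query, key, dropout_p=0.0):
    # Shapes follow the comments kept in the hunk above:
    #   query -> (bs, seq_len, num_attention_heads, head_dim)
    #   key   -> (bs, num_attention_heads, head_dim)
    # Give key a singleton "sequence" dim so the batched matmul broadcasts
    # against query's seq_len (added here so the sketch runs standalone).
    key_t = key.transpose(-1, -2).unsqueeze(1)        # (bs, 1, head_dim, num_heads)
    attn_weights = torch.matmul(query, key_t)         # (bs, seq_len, num_heads, num_heads)
    attn_weights = F.softmax(attn_weights, dim=-1)    # no causal mask: key has no seq length
    return F.dropout(attn_weights, p=dropout_p)       # stands in for self.attn_dropout

# Toy shapes, purely illustrative.
q = torch.randn(2, 5, 4, 8)   # (bs, seq_len, num_heads, head_dim)
k = torch.randn(2, 4, 8)      # (bs, num_heads, head_dim)
print(cross_attn_weights(q, k).shape)  # torch.Size([2, 5, 4, 4])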
@@ -441,8 +418,6 @@ class GPTJRBlock(nn.Module):
         self.attn = GPTJRAttention(config)
         self.mlp = GPTJRMLP(inner_dim, config)
 
-        # TODO: fix for n neighbors
-        # for SBERT this is 384
         self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
         self.cross_attn = GPTJRCrossAttention(config)
         self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
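The second hunk only drops the stale TODO/SBERT comments; the surviving lines still build ln_2 at config.encoder_dim width, presumably to normalize the encoder_dim-sized retrieved encodings ahead of cross_attn. A tiny sketch of that layer in isolation, with encoder_dim=384 taken from the removed comment and PyTorch's default epsilon standing in for config.layer_norm_epsilon (both assumptions, not the repository's config):

import torch
import torch.nn as nn

encoder_dim = 384                           # SBERT embedding width, per the removed comment
ln_2 = nn.LayerNorm(encoder_dim, eps=1e-5)  # repo uses config.layer_norm_epsilon instead

neighbors = torch.randn(2, 1, encoder_dim)  # hypothetical batch of retrieved-neighbor encodings
print(ln_2(neighbors).shape)                # torch.Size([2, 1, 384])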