fix: forward works!

parent 09cddbedc0
commit df79fd64b0
@@ -115,8 +115,7 @@ class GPTJRConfig(PretrainedConfig):
bos_token_id=50256,
eos_token_id=50256,
tie_word_embeddings=False,
-encoder_ndim=4096,
-alpha=.5,
+encoder_dim=4096,
encoder_path=None,
**kwargs
):
@@ -139,8 +138,7 @@ class GPTJRConfig(PretrainedConfig):
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id

-self.encoder_ndim = encoder_ndim
-self.alpha = alpha
+self.encoder_dim = encoder_dim
self.encoder_path = encoder_path

super().__init__(
@@ -154,11 +154,6 @@ class GPTJRAttention(nn.Module):
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)

-# Keep the attention weights computation in fp32 to avoid overflow issues
-# TODO: do we need to do this with bfloat16??
-# query = query.to(torch.float32)
-# key = key.to(torch.float32)

attn_weights = torch.matmul(query, key.transpose(-1, -2))

mask_value = torch.finfo(attn_weights.dtype).min
@@ -188,7 +183,7 @@ class GPTJRAttention(nn.Module):
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
-encoder_hidden_states: Optional[torch.FloatTensor],
+encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
@@ -291,13 +286,131 @@ class GPTJRCrossAttention(GPTJRAttention):
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

-self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-self.v_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False)
-self.q_proj = nn.Linear(config.encoder_ndim, self.embed_dim, bias=False)
+self.k_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
+self.v_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
+self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-self.rotary_dim = None
-if config.rotary_dim is not None:
-self.rotary_dim = config.rotary_dim


+def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
+new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+tensor = tensor.view(new_shape)

+return tensor.permute(0, 2, 1)


+def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+"""
+Merges attn_head_size dim and num_attn_heads dim into hidden dim
+"""
+# tensor -> (bs, seq_len, num_attention_heads, head_dim)
+new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+return tensor.view(new_shape)

+def _attn(
+self,
+query,
+key,
+value,
+attention_mask=None,
+head_mask=None,
+):

+# compute causal mask from causal mask buffer
+# since key and value don't have seq length, just use causal mask as normal
+query_length = query.size(-2)
+causal_mask = self.bias[:, :, : query_length, :query_length].to(torch.bool)

+# Keep the attention weights computation in fp32 to avoid overflow issues
+# TODO: do we need to do this with bfloat16??
+# query = query.to(torch.float32)
+# key = key.to(torch.float32)

+# query -> (bs, seq_len, num_attention_heads, head_dim)
+# key -> (bs, num_attention_heads, head_dim)
+# attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+attn_weights = torch.matmul(query, key.transpose(-1, -2))

+mask_value = torch.finfo(attn_weights.dtype).min
+# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
+# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
+mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+attn_weights = torch.where(causal_mask, attn_weights, mask_value)

+attn_weights = attn_weights / self.scale_attn

+if attention_mask is not None:
+# Apply the attention mask
+# attn mask (1, 1, 1, seq_len)
+attn_weights = (attn_weights.permute(0, 2, 3, 1) + attention_mask).permute(0, 3, 1, 2)

+attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+attn_weights = attn_weights.to(value.dtype)
+attn_weights = self.attn_dropout(attn_weights)

+# Mask heads if we want to
+if head_mask is not None:
+attn_weights = attn_weights * head_mask

+# value -> (bs, num_attention_heads, head_dim)
+# attn_weights -> (bs, seq_len, num_attention_heads, num_attention_heads)
+# attn_output -> (bs, num_attention_heads, seq_len, head_dim)
+attn_output = torch.matmul(attn_weights, value.transpose(-1, -2))

+return attn_output, attn_weights


+def forward(
+self,
+hidden_states: Optional[torch.FloatTensor],
+encoder_hidden_states: Optional[torch.FloatTensor],
+attention_mask: Optional[torch.FloatTensor] = None,
+layer_past: Optional[Tuple[torch.Tensor]] = None,
+head_mask: Optional[torch.FloatTensor] = None,
+use_cache: Optional[bool] = False,
+output_attentions: Optional[bool] = False,
+) -> Union[
+Tuple[torch.Tensor, Tuple[torch.Tensor]],
+Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+]:

+query = self.q_proj(hidden_states)
+# if we are doing cross attention
+key = self.k_proj(encoder_hidden_states)
+value = self.v_proj(encoder_hidden_states)
+# (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
+query = self._split_heads(query, self.num_attention_heads, self.head_dim, False)
+# (bs, dim) -> (bs, num_attention_heads, head_dim)
+key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
+value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)


+key = key.permute(0, 2, 1)
+query = query.permute(0, 2, 1, 3)

+if layer_past is not None:
+past_key = layer_past[0]
+past_value = layer_past[1]
+key = torch.cat((past_key, key), dim=-2)
+value = torch.cat((past_value, value), dim=-2)

+if use_cache is True:
+present = (key, value)
+else:
+present = None

+# compute self-attention: V x Softmax(QK^T)
+attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

+attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+attn_output = self.out_proj(attn_output)
+attn_output = self.resid_dropout(attn_output)

+outputs = (attn_output, present)
+if output_attentions:
+outputs += (attn_weights,)

+return outputs # a, present, (attentions)
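Illustrative sketch, not part of the commit: the helpers added above (_split_knn_attn_heads and _merge_heads) split a single pooled encoder vector per example into attention heads and later merge per-head outputs back into the hidden dimension. A minimal standalone rerun of the same tensor ops with hypothetical sizes, to show the shapes involved:

import torch

bs, seq_len, num_heads, head_dim = 2, 5, 16, 256
embed_dim = num_heads * head_dim  # 4096

# _split_knn_attn_heads: (bs, embed_dim) -> (bs, head_dim, num_heads)
key = torch.randn(bs, embed_dim)
key = key.view(bs, num_heads, head_dim).permute(0, 2, 1)
print(key.shape)  # torch.Size([2, 256, 16])

# the new forward() then permutes it back to (bs, num_heads, head_dim)
key = key.permute(0, 2, 1)
print(key.shape)  # torch.Size([2, 16, 256])

# _merge_heads: (bs, seq_len, num_heads, head_dim) -> (bs, seq_len, embed_dim)
attn_output = torch.randn(bs, seq_len, num_heads, head_dim)
merged = attn_output.view(bs, seq_len, num_heads * head_dim)
print(merged.shape)  # torch.Size([2, 5, 4096])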
@@ -330,9 +443,13 @@ class GPTJRBlock(nn.Module):

# TODO: fix for n neighbors
# for SBERT this is 384
-self.ln_2 = nn.LayerNorm(config.encoder_ndim, eps=config.layer_norm_epsilon)
+self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
self.cross_attn = GPTJRCrossAttention(config)
-self.alpha = config.alpha
+self.cross_attn_mlp = GPTJRMLP(inner_dim, config)

+self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype)
+self.step = 1
+self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1

def forward(
self,
@@ -361,31 +478,48 @@ class GPTJRBlock(nn.Module):
feed_forward_hidden_states = self.mlp(hidden_states)
self_attention_residual = attn_output + feed_forward_hidden_states + residual

-# encoder_hidden_states (bs, knn, encoder_dim)
+# encoder_hidden_states -> (bs, knn, encoder_dim)
encoder_normed = self.ln_2(encoder_hidden_states)
-# TODO: how do we handle neighbors

-# TODO: we have to make sure we're doing masking right here
-# TODO: T5 passes query length to cross attention, do we need that?
-cross_attn_outputs = self.cross_attn(
-residual,
-encoder_hidden_states=encoder_normed,
-attention_mask=attention_mask,
-head_mask=head_mask,
-use_cache=use_cache,
-output_attentions=output_attentions,
+num_neighbors = encoder_normed.shape[1]
+cross_attn_outputs = []
+for k in range(num_neighbors):
+# cross_attn_outputs -> (bs, seq_len, num_attention_heads, head_dim)
+cross_attn_output = self.cross_attn(
+residual,
+encoder_hidden_states=encoder_normed[:, k, :],
+attention_mask=attention_mask,
+head_mask=head_mask,
+use_cache=use_cache,
+output_attentions=output_attentions,
+)
+cross_attn_outputs.append(cross_attn_output[0])

+cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)
+# gpt-j has parallel ff + attn, can do ff on encoder_normed too I guess?
+cross_attn_ff = self.cross_attn_mlp(
+cross_attn_output
)
-cross_attn_output = cross_attn_outputs[0] # output_attn: a, present, (attentions)
-cross_attn_outputs = cross_attn_outputs[1:]

-hidden_states = self.alpha * cross_attn_output + (1 - self.alpha) * self_attention_residual

+alpha = self.alpha if self.training else 0.5

+hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]

+# if training update alpha
+if self.training:
+self.step += 1
+self._update_alpha(self.step)


return outputs # hidden_states, present, (attentions)

+def _update_alpha(self, iteration):
+self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))


class GPTJRPreTrainedModel(PreTrainedModel):
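Illustrative sketch, not part of the commit: the reworked block loops the shared cross-attention over each retrieved neighbor, averages the results, and gates the cross-attention branch against the self-attention branch with an annealed alpha (the same clamp schedule as _update_alpha above). A minimal shape-only toy with hypothetical sizes and a dummy stand-in for the cross-attention layer:

import torch

bs, seq_len, hidden_dim, knn = 2, 5, 8, 4

def toy_cross_attn(resid, neighbor):
    # hypothetical stand-in for GPTJRCrossAttention: returns a tuple like the real layer
    return (resid + neighbor.unsqueeze(1),)

residual = torch.randn(bs, seq_len, hidden_dim)
self_attention_residual = torch.randn(bs, seq_len, hidden_dim)
encoder_normed = torch.randn(bs, knn, hidden_dim)  # (bs, knn, encoder_dim)

# One cross-attention pass per retrieved neighbor, then average over neighbors.
cross_attn_outputs = []
for k in range(encoder_normed.shape[1]):
    cross_attn_outputs.append(toy_cross_attn(residual, encoder_normed[:, k, :])[0])
cross_attn_output = torch.stack(cross_attn_outputs, dim=1).mean(dim=1)

# Blend the two branches; during training alpha is annealed from 1.0 toward 0.5
# with the same clamp used in _update_alpha.
step, world_size = 1000, 1
alpha = torch.clamp(torch.tensor([1 / (step * world_size) ** 0.05]),
                    min=torch.tensor([0.5]), max=torch.tensor([1.0]))
hidden_states = (1 - alpha) * cross_attn_output + alpha * self_attention_residual
print(alpha.item(), hidden_states.shape)  # ~0.71, torch.Size([2, 5, 8])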
@@ -769,7 +903,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):

transformer_outputs = self.transformer(
input_ids,
-encoder_hidden_states=encoder_outputs[0],
+encoder_hidden_states=encoder_outputs,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
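Illustrative sketch, not part of the commit: the causal LM now forwards the full stack of retrieved neighbor encodings rather than encoder_outputs[0], matching the (bs, knn, encoder_dim) layout that GPTJRBlock indexes with encoder_normed[:, k, :]. A hypothetical shape check with made-up sizes (384 matches the SBERT note in the diff):

import torch

bs, knn, encoder_dim = 2, 4, 384
encoder_outputs = torch.randn(bs, knn, encoder_dim)

# This commit passes encoder_outputs itself (not encoder_outputs[0]) to self.transformer.
assert encoder_outputs.shape == (bs, knn, encoder_dim)
assert encoder_outputs[:, 0, :].shape == (bs, encoder_dim)  # one neighbor slice per example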