Mirror of https://github.com/nomic-ai/gpt4all.git
feat: updates to model + config

commit a69475a5f1 (parent 4d5e48a7a9)
@@ -112,6 +112,8 @@ class PythiaSeekConfig(PretrainedConfig):
         total_alpha_steps=0,
         initial_alpha=1,
         final_alpha=.5,
+        cross_attn_layer=9,
+        learnable_alpha=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -134,4 +136,7 @@ class PythiaSeekConfig(PretrainedConfig):
         self.encoder_dim = encoder_dim
         self.total_alpha_steps = total_alpha_steps
         self.initial_alpha = initial_alpha
         self.final_alpha = final_alpha
+        # index of cross attention layer to add
+        self.cross_attn_layer = cross_attn_layer
+        self.learnable_alpha = learnable_alpha
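For orientation, a minimal usage sketch (not part of this commit) showing how the two new config fields might be set when building a PythiaSeekConfig; the import path is an assumption, only the keyword names come from the diff:

# Hypothetical usage sketch; the import path is assumed, not taken from this repo.
from configuration_pythiaseek import PythiaSeekConfig

# cross_attn_layer may be an int (one layer), a list of layer indices,
# or None (every layer), per the PythiaSeekModel changes further down.
config = PythiaSeekConfig(
    cross_attn_layer=[3, 9],   # add cross attention to decoder layers 3 and 9
    learnable_alpha=True,      # learn the mixing weight instead of annealing it
    total_alpha_steps=10_000,  # schedule length, used when learnable_alpha=False
    initial_alpha=1.0,
    final_alpha=0.5,
)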
@@ -17,6 +17,7 @@
 from typing import Optional, Tuple, Union

 import math
+import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -50,7 +51,7 @@ class PythiaSeekPreTrainedModel(PreTrainedModel):
     """

     config_class = PythiaSeekConfig
-    base_model_prefix = "pythiaseek"
+    base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["PythiaSeekLayer"]

@@ -237,15 +238,6 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
     def __init__(self, config):
         super().__init__(config)

-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
-                1, 1, max_positions, max_positions
-            ),
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
-
         self.embed_dim = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_attention_heads
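Presumably the causal bias and masked_bias buffers are dropped here because PythiaSeekCrossAttention attends over encoder states, which needs no causal mask; the inherited PythiaSeekAttention.__init__ still registers those buffers for ordinary self-attention.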
@@ -425,7 +417,7 @@ class PythiaSeekMLP(nn.Module):


 class PythiaSeekLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, add_cross_attention=False):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -433,13 +425,19 @@ class PythiaSeekLayer(nn.Module):
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)

-        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cross_attn = PythiaSeekCrossAttention(config)
-        self.cross_attn_mlp = PythiaSeekMLP(config)
-
-        self.total_alpha_steps = config.total_alpha_steps
-        self.initial_alpha = config.initial_alpha
-        self.final_alpha = config.final_alpha
+        self.add_cross_attention = add_cross_attention
+
+        if add_cross_attention:
+            self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            self.cross_attn = PythiaSeekCrossAttention(config)
+            self.cross_attn_mlp = PythiaSeekMLP(config)
+
+            self.total_alpha_steps = config.total_alpha_steps
+            self.initial_alpha = config.initial_alpha
+            self.final_alpha = config.final_alpha
+            self.learnable_alpha = config.learnable_alpha
+            if self.learnable_alpha:
+                self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)

     def forward(
         self,
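The schedule fields stored above feed PythiaSeekLayer._update_alpha(step), which is not shown in this diff; a plausible reading, given initial_alpha, final_alpha and total_alpha_steps, is a linear anneal between the two endpoints. A standalone sketch under that assumption only (a free function rather than the method, for clarity):

# Assumed behaviour of PythiaSeekLayer._update_alpha; the real method is not part of this diff.
def update_alpha(step: int, initial_alpha: float, final_alpha: float, total_alpha_steps: int) -> float:
    # Linearly move the self-attention weight from initial_alpha to final_alpha
    # over total_alpha_steps steps, then hold it at final_alpha.
    progress = min(step / max(total_alpha_steps, 1), 1.0)
    return initial_alpha + progress * (final_alpha - initial_alpha)

print(update_alpha(0, 1.0, 0.5, 10_000))       # 1.0  (start: all self-attention)
print(update_alpha(5_000, 1.0, 0.5, 10_000))   # 0.75
print(update_alpha(20_000, 1.0, 0.5, 10_000))  # 0.5  (clamped at final_alpha)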
@@ -483,28 +481,34 @@ class PythiaSeekLayer(nn.Module):
             if encoder_hidden_states.dtype != ln_hidden_states.dtype:
                 encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)

-            encoder_normed = self.cross_attn_ln(encoder_hidden_states)
+            if self.add_cross_attention:
+                encoder_normed = self.cross_attn_ln(encoder_hidden_states)

-            # cross_attn_outputs -> (bs, seq_len, dim)
-            cross_attn_output = self.cross_attn(
-                ln_hidden_states,
-                encoder_hidden_states=encoder_normed,
-                attention_mask=attention_mask,
-                head_mask=head_mask,
-                use_cache=use_cache,
-                output_attentions=output_attentions,
-            )
+                # cross_attn_outputs -> (bs, seq_len, dim)
+                cross_attn_output = self.cross_attn(
+                    ln_hidden_states,
+                    encoder_hidden_states=encoder_normed,
+                    attention_mask=attention_mask,
+                    head_mask=head_mask,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                )

-            cross_attn_ff = self.cross_attn_mlp(
-                cross_attn_output[0]
-            )
+                cross_attn_ff = self.cross_attn_mlp(
+                    cross_attn_output[0]
+                )

-            if step is not None:
-                alpha = self._update_alpha(step)
-            else:
-                alpha = 0.5
+                if step is not None:
+                    if self.learnable_alpha:
+                        alpha = F.sigmoid(self.alpha)
+                    else:
+                        alpha = self._update_alpha(step)
+                else:
+                    alpha = 0.5

-            hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+                hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+            else:
+                hidden_states = self_attention_residual

         if use_cache:
             outputs = (hidden_states,) + outputs
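In the gated path above, alpha weights the self-attention residual and (1 - alpha) the cross-attention branch; with learnable_alpha the gate starts at sigmoid(0) = 0.5, an even mix. (F.sigmoid is deprecated in newer PyTorch releases in favour of torch.sigmoid, which computes the same thing.) A standalone sketch of the blend with toy tensors, not the real forward pass:

import torch

# Toy stand-ins for the two branches: (batch, seq_len, hidden)
self_attention_residual = torch.randn(2, 8, 64)
cross_attn_ff = torch.randn(2, 8, 64)

alpha_param = torch.nn.Parameter(torch.zeros(1))  # learnable gate, as in __init__
alpha = torch.sigmoid(alpha_param)                # 0.5 at initialisation
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
print(hidden_states.shape)                        # torch.Size([2, 8, 64])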
@@ -547,7 +551,16 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
         self.config = config

         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([PythiaSeekLayer(config) for _ in range(config.num_hidden_layers)])
+
+        if isinstance(config.cross_attn_layer, int):
+            cross_attn_layers = [config.cross_attn_layer]
+        # if we add cross attention to all layers
+        elif config.cross_attn_layer is None:
+            cross_attn_layers = list(range(config.num_hidden_layers))
+        else:
+            cross_attn_layers = config.cross_attn_layer
+
+        self.layers = nn.ModuleList([PythiaSeekLayer(config, add_cross_attention=i in cross_attn_layers) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

         self.gradient_checkpointing = False
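A small illustration (values made up) of how config.cross_attn_layer selects which decoder layers receive cross attention, mirroring the branch above:

# Illustrative only: mirrors the selection logic in PythiaSeekModel.__init__.
num_hidden_layers = 12

def select_cross_attn_layers(cross_attn_layer):
    if isinstance(cross_attn_layer, int):
        return [cross_attn_layer]              # a single layer, e.g. the config default 9
    elif cross_attn_layer is None:
        return list(range(num_hidden_layers))  # cross attention in every layer
    else:
        return list(cross_attn_layer)          # an explicit list of layer indices

print(select_cross_attn_layers(9))       # [9]
print(select_cross_attn_layers(None))    # [0, 1, 2, ..., 11]
print(select_cross_attn_layers([3, 9]))  # [3, 9]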
@@ -719,7 +732,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)

-        self.pythiaseek = PythiaSeekModel(config)
+        self.gpt_neox = PythiaSeekModel(config)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

         self.hidden_size = config.hidden_size
@@ -727,7 +740,9 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):

         if self.hidden_size != self.encoder_dim:
             self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
-                                              nn.Linear(config.hidden_size * 4, config.hidden_size))
+                                              nn.ReLU(),
+                                              nn.Linear(config.hidden_size * 4, config.hidden_size),
+            )

         # Initialize weights and apply final processing
         self.post_init()
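With the added nn.ReLU(), enc_dec_proj becomes a small two-layer MLP rather than two stacked linear maps; it lifts encoder states of width encoder_dim into the decoder's hidden_size. A shape-level sketch with made-up dimensions:

import torch
from torch import nn

encoder_dim, hidden_size = 1024, 2048  # made-up sizes for illustration only
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, hidden_size * 4),
    nn.ReLU(),  # the non-linearity added in this commit
    nn.Linear(hidden_size * 4, hidden_size),
)

encoder_hidden_states = torch.randn(2, 16, encoder_dim)  # (batch, enc_seq_len, encoder_dim)
print(enc_dec_proj(encoder_hidden_states).shape)         # torch.Size([2, 16, 2048])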
@@ -799,7 +814,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
             encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)

-        outputs = self.pythiaseek(
+        outputs = self.gpt_neox(
             input_ids,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,