Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-07 04:20:59 +00:00)

Commit a69475a5f1 (parent 4d5e48a7a9): feat: updates to model + config
@@ -112,6 +112,8 @@ class PythiaSeekConfig(PretrainedConfig):
         total_alpha_steps=0,
         initial_alpha=1,
         final_alpha=.5,
+        cross_attn_layer=9,
+        learnable_alpha=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -135,3 +137,6 @@ class PythiaSeekConfig(PretrainedConfig):
         self.total_alpha_steps = total_alpha_steps
         self.initial_alpha = initial_alpha
         self.final_alpha = final_alpha
+        # index of cross attention layer to add
+        self.cross_attn_layer = cross_attn_layer
+        self.learnable_alpha = learnable_alpha
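The two new arguments plug into the existing alpha-schedule fields. A minimal usage sketch (illustrative only; any constructor arguments not shown keep whatever GPT-NeoX-style defaults the config already defines, and the numeric values below are examples, not recommendations):

config = PythiaSeekConfig(
    cross_attn_layer=9,        # int: one layer; None: all layers; list: the listed layers
    learnable_alpha=False,     # False: blend weight follows the schedule fields below
    total_alpha_steps=10_000,  # illustrative value
    initial_alpha=1.0,
    final_alpha=0.5,
)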
@@ -17,6 +17,7 @@
 from typing import Optional, Tuple, Union
 
 import math
+import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -50,7 +51,7 @@ class PythiaSeekPreTrainedModel(PreTrainedModel):
     """
 
     config_class = PythiaSeekConfig
-    base_model_prefix = "pythiaseek"
+    base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["PythiaSeekLayer"]
 
@@ -237,15 +238,6 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
     def __init__(self, config):
         super().__init__(config)
 
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
-                1, 1, max_positions, max_positions
-            ),
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
-
         self.embed_dim = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_attention_heads
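Dropping the triangular "bias" and "masked_bias" buffers fits a cross-attention module: decoder queries attend over encoder positions, so no causal mask is applied to the scores. A standalone sketch of that unmasked attention (shapes are illustrative, not taken from the model):

import torch

q = torch.randn(1, 4, 8, 16)    # (batch, heads, decoder_len, head_dim)
k = torch.randn(1, 4, 10, 16)   # (batch, heads, encoder_len, head_dim) from the encoder
scores = q @ k.transpose(-1, -2) / 16 ** 0.5   # (1, 4, 8, 10), no lower-triangular mask
attn = scores.softmax(dim=-1)   # every decoder position can see every encoder position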
@@ -425,7 +417,7 @@ class PythiaSeekMLP(nn.Module):
 
 
 class PythiaSeekLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, add_cross_attention=False):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -433,6 +425,9 @@ class PythiaSeekLayer(nn.Module):
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)
 
+        self.add_cross_attention = add_cross_attention
+
+        if add_cross_attention:
             self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
             self.cross_attn = PythiaSeekCrossAttention(config)
             self.cross_attn_mlp = PythiaSeekMLP(config)
@@ -440,6 +435,9 @@ class PythiaSeekLayer(nn.Module):
             self.total_alpha_steps = config.total_alpha_steps
             self.initial_alpha = config.initial_alpha
             self.final_alpha = config.final_alpha
+            self.learnable_alpha = config.learnable_alpha
+            if self.learnable_alpha:
+                self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
 
     def forward(
         self,
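With learnable_alpha set, the blend weight becomes a single learned scalar squashed through a sigmoid rather than a scheduled value. A minimal standalone sketch of that gate (not repo code):

import torch
import torch.nn.functional as F

alpha_param = torch.nn.Parameter(torch.zeros(1))
alpha = F.sigmoid(alpha_param)   # tensor([0.5]) at init; moves as the parameter is trained

Initializing the parameter at zero means the gate starts at an even 50/50 blend, matching the alpha = 0.5 fallback used in the forward pass below.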
@@ -483,6 +481,7 @@ class PythiaSeekLayer(nn.Module):
             if encoder_hidden_states.dtype != ln_hidden_states.dtype:
                 encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
 
+        if self.add_cross_attention:
             encoder_normed = self.cross_attn_ln(encoder_hidden_states)
 
             # cross_attn_outputs -> (bs, seq_len, dim)
@@ -500,11 +499,16 @@ class PythiaSeekLayer(nn.Module):
             )
 
             if step is not None:
+                if self.learnable_alpha:
+                    alpha = F.sigmoid(self.alpha)
+                else:
                     alpha = self._update_alpha(step)
             else:
                 alpha = 0.5
 
             hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        else:
+            hidden_states = self_attention_residual
 
         if use_cache:
             outputs = (hidden_states,) + outputs
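The _update_alpha helper is not part of this diff. Assuming it is a simple linear anneal from initial_alpha to final_alpha over total_alpha_steps, it could look roughly like the following (hypothetical sketch, not the repo's implementation):

def _update_alpha(self, step):
    # Hypothetical linear schedule: ramp from initial_alpha to final_alpha, then hold.
    frac = min(step / max(self.total_alpha_steps, 1), 1.0)
    return self.initial_alpha + frac * (self.final_alpha - self.initial_alpha)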
@@ -547,7 +551,16 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
         self.config = config
 
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([PythiaSeekLayer(config) for _ in range(config.num_hidden_layers)])
+        if isinstance(config.cross_attn_layer, int):
+            cross_attn_layers = [config.cross_attn_layer]
+        # if we add cross attention to all layers
+        elif config.cross_attn_layer is None:
+            cross_attn_layers = list(range(config.num_hidden_layers))
+        else:
+            cross_attn_layers = config.cross_attn_layer
+
+        self.layers = nn.ModuleList([PythiaSeekLayer(config, add_cross_attention=i in cross_attn_layers) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.gradient_checkpointing = False
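config.cross_attn_layer therefore accepts an int, None, or a list of layer indices. A small standalone helper mirroring that selection logic, with example outputs (illustrative, not part of the repo):

def select_cross_attn_layers(cross_attn_layer, num_hidden_layers):
    if isinstance(cross_attn_layer, int):
        return [cross_attn_layer]                 # a single layer
    elif cross_attn_layer is None:
        return list(range(num_hidden_layers))     # every layer
    else:
        return list(cross_attn_layer)             # an explicit list of layers

print(select_cross_attn_layers(9, 16))       # [9]
print(select_cross_attn_layers(None, 4))     # [0, 1, 2, 3]
print(select_cross_attn_layers([3, 7], 16))  # [3, 7]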
@@ -719,7 +732,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.pythiaseek = PythiaSeekModel(config)
+        self.gpt_neox = PythiaSeekModel(config)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         self.hidden_size = config.hidden_size
@@ -727,7 +740,9 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
 
         if self.hidden_size != self.encoder_dim:
             self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
-                                              nn.Linear(config.hidden_size * 4, config.hidden_size))
+                                              nn.ReLU(),
+                                              nn.Linear(config.hidden_size * 4, config.hidden_size),
+            )
 
         # Initialize weights and apply final processing
         self.post_init()
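Adding the ReLU turns the projection into a proper two-layer MLP; without it, two stacked Linear layers collapse into a single linear map. A standalone sketch of the projection with illustrative dimensions:

import torch
from torch import nn

encoder_dim, hidden_size = 1024, 2048                     # illustrative sizes
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, hidden_size * 4),
    nn.ReLU(),
    nn.Linear(hidden_size * 4, hidden_size),
)
encoder_hidden_states = torch.randn(2, 16, encoder_dim)   # (batch, seq, encoder_dim)
projected = enc_dec_proj(encoder_hidden_states)           # -> (2, 16, hidden_size)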
@@ -799,7 +814,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
             encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
 
-        outputs = self.pythiaseek(
+        outputs = self.gpt_neox(
             input_ids,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,