feat: updates to model + config

Zach Nussbaum 2023-05-15 14:16:48 +00:00
parent 4d5e48a7a9
commit a69475a5f1
2 changed files with 61 additions and 41 deletions


@@ -112,6 +112,8 @@ class PythiaSeekConfig(PretrainedConfig):
         total_alpha_steps=0,
         initial_alpha=1,
         final_alpha=.5,
+        cross_attn_layer=9,
+        learnable_alpha=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -134,4 +136,7 @@ class PythiaSeekConfig(PretrainedConfig):
         self.encoder_dim = encoder_dim
         self.total_alpha_steps = total_alpha_steps
         self.initial_alpha = initial_alpha
-        self.final_alpha = final_alpha
+        self.final_alpha = final_alpha
+        # index of cross attention layer to add
+        self.cross_attn_layer = cross_attn_layer
+        self.learnable_alpha = learnable_alpha
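A quick usage note on the two new fields: cross_attn_layer selects which decoder layer(s) get a cross-attention block, and learnable_alpha switches the mixing weight from a schedule to a learned scalar. A minimal construction sketch follows; the import path and the concrete values are placeholders, not part of this commit.

from configuration_pythiaseek import PythiaSeekConfig  # assumed module name

config = PythiaSeekConfig(
    total_alpha_steps=10_000,   # steps over which alpha is annealed
    initial_alpha=1.0,          # start fully on the self-attention residual
    final_alpha=0.5,            # end at an even mix
    cross_attn_layer=9,         # int = one layer; None = all layers; list = explicit indices
    learnable_alpha=False,      # True replaces the schedule with a learned, sigmoid-squashed scalar
)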


@@ -17,6 +17,7 @@
 from typing import Optional, Tuple, Union
 import math
+import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -50,7 +51,7 @@ class PythiaSeekPreTrainedModel(PreTrainedModel):
     """
     config_class = PythiaSeekConfig
-    base_model_prefix = "pythiaseek"
+    base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["PythiaSeekLayer"]
@@ -237,15 +238,6 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
     def __init__(self, config):
         super().__init__(config)
         max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
-                1, 1, max_positions, max_positions
-            ),
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
         self.embed_dim = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_attention_heads
@@ -425,7 +417,7 @@ class PythiaSeekMLP(nn.Module):
 class PythiaSeekLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, add_cross_attention=False):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -433,13 +425,19 @@ class PythiaSeekLayer(nn.Module):
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)
-        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cross_attn = PythiaSeekCrossAttention(config)
-        self.cross_attn_mlp = PythiaSeekMLP(config)
-        self.total_alpha_steps = config.total_alpha_steps
-        self.initial_alpha = config.initial_alpha
-        self.final_alpha = config.final_alpha
+        self.add_cross_attention = add_cross_attention
+        if add_cross_attention:
+            self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            self.cross_attn = PythiaSeekCrossAttention(config)
+            self.cross_attn_mlp = PythiaSeekMLP(config)
+            self.total_alpha_steps = config.total_alpha_steps
+            self.initial_alpha = config.initial_alpha
+            self.final_alpha = config.final_alpha
+            self.learnable_alpha = config.learnable_alpha
+            if self.learnable_alpha:
+                self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)

     def forward(
         self,
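_update_alpha(step) is called in the forward pass below, but its body is not part of this diff. Given the initial_alpha, final_alpha, and total_alpha_steps stored above, a plausible reading is a simple linear anneal; the sketch below is an assumption, not the code from this commit.

def _update_alpha(self, step: int) -> float:
    # Hypothetical sketch only; the real _update_alpha is not shown in this commit.
    # Fraction of the schedule completed, clamped to [0, 1].
    progress = min(max(step / max(self.total_alpha_steps, 1), 0.0), 1.0)
    # Linear interpolation from initial_alpha toward final_alpha.
    return self.initial_alpha + progress * (self.final_alpha - self.initial_alpha)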
@@ -483,28 +481,34 @@ class PythiaSeekLayer(nn.Module):
         if encoder_hidden_states.dtype != ln_hidden_states.dtype:
             encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
-        encoder_normed = self.cross_attn_ln(encoder_hidden_states)
+        if self.add_cross_attention:
+            encoder_normed = self.cross_attn_ln(encoder_hidden_states)
-        # cross_attn_outputs -> (bs, seq_len, dim)
-        cross_attn_output = self.cross_attn(
-            ln_hidden_states,
-            encoder_hidden_states=encoder_normed,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-        )
+            # cross_attn_outputs -> (bs, seq_len, dim)
+            cross_attn_output = self.cross_attn(
+                ln_hidden_states,
+                encoder_hidden_states=encoder_normed,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
-        cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output[0]
-        )
+            cross_attn_ff = self.cross_attn_mlp(
+                cross_attn_output[0]
+            )
-        if step is not None:
-            alpha = self._update_alpha(step)
+            if step is not None:
+                if self.learnable_alpha:
+                    alpha = F.sigmoid(self.alpha)
+                else:
+                    alpha = self._update_alpha(step)
+            else:
+                alpha = 0.5
+            hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
-        else:
-            alpha = 0.5
-        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        else:
+            hidden_states = self_attention_residual

         if use_cache:
             outputs = (hidden_states,) + outputs
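The gating above mixes the cross-attention feed-forward branch with the self-attention residual, with alpha weighting the residual. A standalone sanity check (the shapes and variable names below are examples, not from the commit) also shows why initializing the learnable parameter at zero reproduces the fixed 0.5 fallback:

import torch
import torch.nn.functional as F

cross_attn_ff = torch.randn(2, 8, 16)            # (bs, seq_len, dim)
self_attention_residual = torch.randn(2, 8, 16)  # (bs, seq_len, dim)

learnable_alpha = torch.zeros(1)                 # matches nn.Parameter(torch.zeros(1)) at init
alpha = F.sigmoid(learnable_alpha)               # sigmoid(0) == 0.5, an even mix at initialization
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
assert hidden_states.shape == cross_attn_ff.shape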
@@ -547,7 +551,16 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
         self.config = config
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([PythiaSeekLayer(config) for _ in range(config.num_hidden_layers)])
+        if isinstance(config.cross_attn_layer, int):
+            cross_attn_layers = [config.cross_attn_layer]
+        # if we add cross attention to all layers
+        elif config.cross_attn_layer is None:
+            cross_attn_layers = list(range(config.num_hidden_layers))
+        else:
+            cross_attn_layers = config.cross_attn_layer
+        self.layers = nn.ModuleList([PythiaSeekLayer(config, add_cross_attention=i in cross_attn_layers) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
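The branching above normalizes config.cross_attn_layer into a list of layer indices: an int selects one layer, None selects every layer, and any other value is used as an explicit list. The same rule in isolation, with a hypothetical helper name and example values:

from typing import List

def select_cross_attn_layers(cross_attn_layer, num_hidden_layers: int) -> List[int]:
    # int -> single layer; None -> all layers; otherwise an explicit list of indices.
    if isinstance(cross_attn_layer, int):
        return [cross_attn_layer]
    if cross_attn_layer is None:
        return list(range(num_hidden_layers))
    return list(cross_attn_layer)

assert select_cross_attn_layers(9, 16) == [9]             # the new config default
assert select_cross_attn_layers(None, 4) == [0, 1, 2, 3]  # cross attention in every layer
assert select_cross_attn_layers([3, 7], 16) == [3, 7]     # explicit subset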
@@ -719,7 +732,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.pythiaseek = PythiaSeekModel(config)
+        self.gpt_neox = PythiaSeekModel(config)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
         self.hidden_size = config.hidden_size
@@ -727,7 +740,9 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
         if self.hidden_size != self.encoder_dim:
             self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
-                                              nn.Linear(config.hidden_size * 4, config.hidden_size))
+                                              nn.ReLU(),
+                                              nn.Linear(config.hidden_size * 4, config.hidden_size),
+                                              )
         # Initialize weights and apply final processing
         self.post_init()
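Adding the nn.ReLU() turns the projection into a proper two-layer MLP from encoder_dim to hidden_size through a 4x-wide intermediate. A shape check with made-up sizes (1024 and 512 are illustrative, not taken from the config):

import torch
from torch import nn

encoder_dim, hidden_size = 1024, 512
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, hidden_size * 4),
    nn.ReLU(),
    nn.Linear(hidden_size * 4, hidden_size),
)
encoder_hidden_states = torch.randn(2, 128, encoder_dim)  # (batch, seq_len, encoder_dim)
print(enc_dec_proj(encoder_hidden_states).shape)          # torch.Size([2, 128, 512])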
@@ -799,7 +814,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
             encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
-        outputs = self.pythiaseek(
+        outputs = self.gpt_neox(
             input_ids,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,