From a69475a5f1871e9f35eb4658ffe1939f242754a1 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Mon, 15 May 2023 14:16:48 +0000
Subject: [PATCH] feat: updates to model + config

---
 .../pythiaseek/configuration_pythiaseek.py   |  7 +-
 .../models/pythiaseek/modeling_pythiaseek.py | 95 +++++++++++--------
 2 files changed, 61 insertions(+), 41 deletions(-)

diff --git a/gpt4all/models/pythiaseek/configuration_pythiaseek.py b/gpt4all/models/pythiaseek/configuration_pythiaseek.py
index d08369cf..98004005 100644
--- a/gpt4all/models/pythiaseek/configuration_pythiaseek.py
+++ b/gpt4all/models/pythiaseek/configuration_pythiaseek.py
@@ -112,6 +112,8 @@ class PythiaSeekConfig(PretrainedConfig):
         total_alpha_steps=0,
         initial_alpha=1,
         final_alpha=.5,
+        cross_attn_layer=9,
+        learnable_alpha=False,
         **kwargs,
     ):
         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -134,4 +136,7 @@ class PythiaSeekConfig(PretrainedConfig):
         self.encoder_dim = encoder_dim
         self.total_alpha_steps = total_alpha_steps
         self.initial_alpha = initial_alpha
-        self.final_alpha = final_alpha
\ No newline at end of file
+        self.final_alpha = final_alpha
+        # index of cross attention layer to add
+        self.cross_attn_layer = cross_attn_layer
+        self.learnable_alpha = learnable_alpha
\ No newline at end of file
diff --git a/gpt4all/models/pythiaseek/modeling_pythiaseek.py b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
index db7c0661..ef95a3f6 100644
--- a/gpt4all/models/pythiaseek/modeling_pythiaseek.py
+++ b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
@@ -17,6 +17,7 @@ from typing import Optional, Tuple, Union
 
 import math
 
+import torch.nn.functional as F
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -50,7 +51,7 @@ class PythiaSeekPreTrainedModel(PreTrainedModel):
     """
 
     config_class = PythiaSeekConfig
-    base_model_prefix = "pythiaseek"
+    base_model_prefix = "gpt_neox"
     supports_gradient_checkpointing = True
     _no_split_modules = ["PythiaSeekLayer"]
 
@@ -237,15 +238,6 @@ class PythiaSeekCrossAttention(PythiaSeekAttention):
 
     def __init__(self, config):
         super().__init__(config)
-        max_positions = config.max_position_embeddings
-        self.register_buffer(
-            "bias",
-            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
-                1, 1, max_positions, max_positions
-            ),
-        )
-        self.register_buffer("masked_bias", torch.tensor(-1e9))
-
         self.embed_dim = config.hidden_size
         self.num_attention_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_attention_heads
@@ -425,7 +417,7 @@ class PythiaSeekMLP(nn.Module):
 
 
 class PythiaSeekLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, add_cross_attention=False):
         super().__init__()
         self.use_parallel_residual = config.use_parallel_residual
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -433,13 +425,19 @@
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)
 
-        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.cross_attn = PythiaSeekCrossAttention(config)
-        self.cross_attn_mlp = PythiaSeekMLP(config)
-
-        self.total_alpha_steps = config.total_alpha_steps
-        self.initial_alpha = config.initial_alpha
-        self.final_alpha = config.final_alpha
+        self.add_cross_attention = add_cross_attention
+
+        if add_cross_attention:
+            self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+            self.cross_attn = PythiaSeekCrossAttention(config)
+            self.cross_attn_mlp = PythiaSeekMLP(config)
+
+            self.total_alpha_steps = config.total_alpha_steps
+            self.initial_alpha = config.initial_alpha
+            self.final_alpha = config.final_alpha
+            self.learnable_alpha = config.learnable_alpha
+            if self.learnable_alpha:
+                self.alpha = nn.Parameter(torch.zeros(1), requires_grad=True)
 
     def forward(
         self,
@@ -483,28 +481,34 @@
         if encoder_hidden_states.dtype != ln_hidden_states.dtype:
             encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
 
-        encoder_normed = self.cross_attn_ln(encoder_hidden_states)
+        if self.add_cross_attention:
+            encoder_normed = self.cross_attn_ln(encoder_hidden_states)
 
-        # cross_attn_outputs -> (bs, seq_len, dim)
-        cross_attn_output = self.cross_attn(
-            ln_hidden_states,
-            encoder_hidden_states=encoder_normed,
-            attention_mask=attention_mask,
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-        )
+            # cross_attn_outputs -> (bs, seq_len, dim)
+            cross_attn_output = self.cross_attn(
+                ln_hidden_states,
+                encoder_hidden_states=encoder_normed,
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+            )
 
-        cross_attn_ff = self.cross_attn_mlp(
-            cross_attn_output[0]
-        )
+            cross_attn_ff = self.cross_attn_mlp(
+                cross_attn_output[0]
+            )
 
-        if step is not None:
-            alpha = self._update_alpha(step)
+            if step is not None:
+                if self.learnable_alpha:
+                    alpha = F.sigmoid(self.alpha)
+                else:
+                    alpha = self._update_alpha(step)
+            else:
+                alpha = 0.5
+
+            hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
         else:
-            alpha = 0.5
-
-        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+            hidden_states = self_attention_residual
 
         if use_cache:
             outputs = (hidden_states,) + outputs
@@ -547,7 +551,16 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
         self.config = config
 
         self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
-        self.layers = nn.ModuleList([PythiaSeekLayer(config) for _ in range(config.num_hidden_layers)])
+
+        if isinstance(config.cross_attn_layer, int):
+            cross_attn_layers = [config.cross_attn_layer]
+        # if we add cross attention to all layers
+        elif config.cross_attn_layer is None:
+            cross_attn_layers = list(range(config.num_hidden_layers))
+        else:
+            cross_attn_layers = config.cross_attn_layer
+
+        self.layers = nn.ModuleList([PythiaSeekLayer(config, add_cross_attention=i in cross_attn_layers) for i in range(config.num_hidden_layers)])
         self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
         self.gradient_checkpointing = False
@@ -719,7 +732,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
 
     def __init__(self, config):
         super().__init__(config)
-        self.pythiaseek = PythiaSeekModel(config)
+        self.gpt_neox = PythiaSeekModel(config)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         self.hidden_size = config.hidden_size
@@ -727,7 +740,9 @@
 
         if self.hidden_size != self.encoder_dim:
             self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
-                                              nn.Linear(config.hidden_size * 4, config.hidden_size))
+                                              nn.ReLU(),
+                                              nn.Linear(config.hidden_size * 4, config.hidden_size),
+                                              )
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -799,7 +814,7 @@
             encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
 
-        outputs = self.pythiaseek(
+        outputs = self.gpt_neox(
             input_ids,
             encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
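Context for the gating this patch introduces: each decoder layer selected by `cross_attn_layer` blends its cross-attention branch back into the self-attention residual as `(1 - alpha) * cross_attn_ff + alpha * self_attention_residual`, where `alpha` is either annealed over `total_alpha_steps` via `_update_alpha` (not shown in the diff) or, when `learnable_alpha` is set, learned through a sigmoid-squashed parameter. Below is a minimal standalone sketch of that mixing, assuming a linear annealing schedule from `initial_alpha` to `final_alpha`; the `CrossAttnGate` name and the schedule are illustrative, not part of the patch.

```python
import torch
import torch.nn as nn


class CrossAttnGate(nn.Module):
    """Illustrative gate (not the patch's code): blends a cross-attention
    branch into the self-attention residual with a scalar alpha, mirroring
    (1 - alpha) * cross_attn_ff + alpha * self_attention_residual."""

    def __init__(self, initial_alpha=1.0, final_alpha=0.5,
                 total_alpha_steps=1000, learnable_alpha=False):
        super().__init__()
        self.initial_alpha = initial_alpha
        self.final_alpha = final_alpha
        self.total_alpha_steps = total_alpha_steps
        self.learnable_alpha = learnable_alpha
        if learnable_alpha:
            # sigmoid(0) = 0.5, so training starts from an even mix
            self.alpha = nn.Parameter(torch.zeros(1))

    def _annealed_alpha(self, step):
        # assumed linear interpolation from initial_alpha to final_alpha
        frac = min(step / max(self.total_alpha_steps, 1), 1.0)
        return self.initial_alpha + frac * (self.final_alpha - self.initial_alpha)

    def forward(self, cross_attn_ff, self_attention_residual, step=None):
        if step is not None:
            alpha = torch.sigmoid(self.alpha) if self.learnable_alpha else self._annealed_alpha(step)
        else:
            alpha = 0.5  # inference default used by the patch
        return (1 - alpha) * cross_attn_ff + alpha * self_attention_residual


# usage: mix two (batch, seq_len, hidden) tensors at training step 100
gate = CrossAttnGate(learnable_alpha=True)
mixed = gate(torch.randn(2, 8, 64), torch.randn(2, 8, 64), step=100)
```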