diff --git a/gpt4all/models/pythiaseek/configuration_pythiaseek.py b/gpt4all/models/pythiaseek/configuration_pythiaseek.py
index 5960627f..d08369cf 100644
--- a/gpt4all/models/pythiaseek/configuration_pythiaseek.py
+++ b/gpt4all/models/pythiaseek/configuration_pythiaseek.py
@@ -129,4 +129,9 @@ class PythiaSeekConfig(PretrainedConfig):
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
-        self.use_parallel_residual = use_parallel_residual
\ No newline at end of file
+        self.use_parallel_residual = use_parallel_residual
+
+        self.encoder_dim = encoder_dim
+        self.total_alpha_steps = total_alpha_steps
+        self.initial_alpha = initial_alpha
+        self.final_alpha = final_alpha
\ No newline at end of file
diff --git a/gpt4all/models/pythiaseek/modeling_pythiaseek.py b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
index 473f58c6..db7c0661 100644
--- a/gpt4all/models/pythiaseek/modeling_pythiaseek.py
+++ b/gpt4all/models/pythiaseek/modeling_pythiaseek.py
@@ -233,6 +233,131 @@ class PythiaSeekAttention(nn.Module):
         attn_output = torch.matmul(attn_weights, value)
         return attn_output, attn_weights

+
+class PythiaSeekCrossAttention(PythiaSeekAttention):
+    def __init__(self, config):
+        super().__init__(config)
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e9))
+
+        self.embed_dim = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        if self.head_dim * self.num_attention_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+                f" `num_attention_heads`: {self.num_attention_heads})."
+            )
+        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+
+    def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
+        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+
+        return tensor.permute(0, 1, 3, 2)
+
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden dim
+        """
+        # tensor -> (bs, seq_len, num_attention_heads, head_dim)
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def _attn(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        head_mask=None,
+    ):
+
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        # value -> (bs, num_attention_heads, neighbors, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.FloatTensor],
+        encoder_hidden_states: Optional[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[
+        Tuple[torch.Tensor, Tuple[torch.Tensor]],
+        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+    ]:
+
+        query = self.q_proj(hidden_states)
+        # if we are doing cross attention
+        key = self.k_proj(encoder_hidden_states)
+        value = self.v_proj(encoder_hidden_states)
+        # (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim)
+        # (bs, knn, dim) -> (bs, knn, head_dim, num_attention_heads)
+        key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
+        value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
+
+
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)
+
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # compute cross-attention: V x Softmax(QK^T)
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # a, present, (attentions)
+

 def attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
@@ -308,18 +433,29 @@ class PythiaSeekLayer(nn.Module):
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)

+        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.cross_attn = PythiaSeekCrossAttention(config)
+        self.cross_attn_mlp = PythiaSeekMLP(config)
+
+        self.total_alpha_steps = config.total_alpha_steps
+        self.initial_alpha = config.initial_alpha
+        self.final_alpha = config.final_alpha
+
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
+        encoder_hidden_states: Optional[torch.FloatTensor],
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        step: Optional[int] = None,
     ):
+        ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
-            self.input_layernorm(hidden_states),
+            ln_hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             layer_past=layer_past,
@@ -334,74 +470,75 @@ class PythiaSeekLayer(nn.Module):
             # pseudocode:
             # x = x + attn(ln1(x)) + mlp(ln2(x))
             mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-            hidden_states = mlp_output + attn_output + hidden_states
+            self_attention_residual = mlp_output + attn_output + hidden_states
         else:
             # pseudocode:
             # x = x + attn(ln1(x))
             # x = x + mlp(ln2(x))
             attn_output = attn_output + hidden_states
             mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
-            hidden_states = mlp_output + attn_output
+            self_attention_residual = mlp_output + attn_output
+
+        # encoder_hidden_states -> (bs, knn, encoder_dim)
+        if encoder_hidden_states.dtype != ln_hidden_states.dtype:
+            encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
+
+        encoder_normed = self.cross_attn_ln(encoder_hidden_states)
+
+        # cross_attn_output -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            ln_hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        cross_attn_ff = self.cross_attn_mlp(
+            cross_attn_output[0]
+        )
+
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
+
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

         if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (hidden_states,) + outputs
         else:
-            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)
+            outputs = (hidden_states,) + outputs[1:]

-        return outputs
+        return outputs  # hidden_states, present, (attentions)
+
+    def _update_alpha(self, current_step):
+        """
+        Computes the cross-attention mixing weight `alpha` for the current step using a cosine decay
+        schedule from `self.initial_alpha` to `self.final_alpha` over `self.total_alpha_steps` steps.
+        Args:
+            current_step (int): The current training step.

-GPT_NEOX_START_DOCSTRING = r"""
-    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
-    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
-    behavior.
+        Returns:
+            float: The alpha value for the current step.
+        """
+        initial_alpha = self.initial_alpha
+        final_alpha = self.final_alpha
+        if current_step >= self.total_alpha_steps:
+            return final_alpha

-    Parameters:
-        config ([`~PythiaSeekConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
+        # Compute the cosine decay factor
+        cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))

-GPT_NEOX_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `({0})`):
-            Indices of input sequence tokens in the vocabulary.
+        # Compute the current alpha
+        alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay

-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
+        return alpha

-            [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
-""" class PythiaSeekModel(PythiaSeekPreTrainedModel): @@ -427,6 +564,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -436,6 +574,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + step: Optional[int] = None ) -> Union[Tuple, BaseModelOutputWithPast]: r""" past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): @@ -528,26 +667,29 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel): def create_custom_forward(module): def custom_forward(*inputs): # None for layer_past - return module(*inputs, use_cache, None, output_attentions) + return module(*inputs, use_cache, None, output_attentions, step) return custom_forward outputs = torch.utils.checkpoint.checkpoint( create_custom_forward(layer), hidden_states, + encoder_hidden_states, attention_mask, position_ids, head_mask[i], ) else: outputs = layer( - hidden_states, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], layer_past=layer_past, use_cache=use_cache, output_attentions=output_attentions, + step=step, ) hidden_states = outputs[0] if use_cache is True: @@ -577,9 +719,16 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel): def __init__(self, config): super().__init__(config) - self.gpt_neox = PythiaSeekModel(config) + self.pythiaseek = PythiaSeekModel(config) self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.hidden_size = config.hidden_size + self.encoder_dim = config.encoder_dim + + if self.hidden_size != self.encoder_dim: + self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4), + nn.Linear(config.hidden_size * 4, config.hidden_size)) + # Initialize weights and apply final processing self.post_init() @@ -591,17 +740,19 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel): def forward( self, - input_ids: Optional[torch.LongTensor] = None, + input_ids: torch.LongTensor, + encoder_hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, position_ids: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + step: Optional[int] = None ) -> Union[Tuple, CausalLMOutputWithPast]: r""" past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): @@ -644,8 +795,13 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel): ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs = self.gpt_neox( + if self.hidden_size != 
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
+            encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
+
+        outputs = self.pythiaseek(
             input_ids,
+            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             head_mask=head_mask,
@@ -655,6 +811,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )
         hidden_states = outputs[0]
diff --git a/gpt4all/models/pythiaseek/test_pythiaseek.py b/gpt4all/models/pythiaseek/test_pythiaseek.py
new file mode 100644
index 00000000..2d90c984
--- /dev/null
+++ b/gpt4all/models/pythiaseek/test_pythiaseek.py
@@ -0,0 +1,55 @@
+import torch
+from gpt4all.models import PythiaSeekConfig, PythiaSeekForCausalLM
+from transformers import AutoTokenizer, AutoModel
+
+# seed torch
+
+torch.manual_seed(0)
+
+config = PythiaSeekConfig(encoder_dim=384,
+                          hidden_size=1024,
+                          intermediate_size=4096,
+                          num_hidden_layers=4,
+                          num_attention_heads=4)
+print("loaded config")
+
+print("loading model")
+model = PythiaSeekForCausalLM(config)
+print("loaded model")
+
+
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.model_max_length = 1024
+
+encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+text = "The quick brown fox jumps over the lazy dog."
+print("Encoded knn")
+tokenized = encoder_tokenizer(text, return_tensors="pt")
+
+# bs, seq_len, dim
+encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
+
+# make 2 neighbors
+# (bs, knn, encoding_dim)
+encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
+
+inputs = "What did the fox do?"
+
+print("Encoded inputs")
+tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
+
+print("Running model")
+outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
+
+print(outputs)
+print(outputs[0].shape)