From c9dd9152c3fa4cae1049a3bfebd9a4a4b5427665 Mon Sep 17 00:00:00 2001
From: Zach Nussbaum
Date: Mon, 1 May 2023 21:38:36 +0000
Subject: [PATCH] feat: model def + metrics

---
 gpt4all/models/configuration_gpt_jr.py |  1 -
 gpt4all/models/modeling_gpt_jr.py      | 56 +++++++++----------------
 gpt4all/train/metrics.py               | 50 +++++++++++++++++++++++
 3 files changed, 68 insertions(+), 39 deletions(-)
 create mode 100644 gpt4all/train/metrics.py

diff --git a/gpt4all/models/configuration_gpt_jr.py b/gpt4all/models/configuration_gpt_jr.py
index fdb9eec4..314a87af 100644
--- a/gpt4all/models/configuration_gpt_jr.py
+++ b/gpt4all/models/configuration_gpt_jr.py
@@ -139,7 +139,6 @@ class GPTJRConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id
 
         self.encoder_dim = encoder_dim
-        self.encoder_path = encoder_path
 
         super().__init__(
             bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py
index a38ad5b0..36e053ea 100644
--- a/gpt4all/models/modeling_gpt_jr.py
+++ b/gpt4all/models/modeling_gpt_jr.py
@@ -423,8 +423,6 @@ class GPTJRBlock(nn.Module):
         self.cross_attn = GPTJRCrossAttention(config)
         self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
 
-        self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype)
-        self.step = 1
         self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1
 
     def forward(
@@ -436,6 +434,7 @@ class GPTJRBlock(nn.Module):
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        step: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         # shape (bs, seq_len, hidden_dim)
         residual = hidden_states
@@ -455,7 +454,8 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual
 
         # encoder_hidden_states -> (bs, knn, encoder_dim)
-        # may not need, can norm encodings
+        if encoder_hidden_states.dtype != hidden_states.dtype:
+            encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
         encoder_normed = self.ln_2(encoder_hidden_states)
 
         # cross_attn_outputs -> (bs, seq_len, dim)
@@ -473,24 +473,22 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
 
-        alpha = self.alpha if self.training else 0.5
-        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        if step is not None:
+            alpha = self._update_alpha(step)
+            alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
+        else:
+            alpha = 0.5
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
 
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
             outputs = (hidden_states,) + outputs[1:]
 
-        # if training update alpha
-        if self.training:
-            self.step += 1
-            self._update_alpha(self.step)
-
-
         return outputs  # hidden_states, present, (attentions)
 
     def _update_alpha(self, iteration):
-        self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
+        return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
 
 
 class GPTJRPreTrainedModel(PreTrainedModel):
@@ -597,6 +595,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -704,7 +703,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
+                        return module(*inputs, use_cache, output_attentions, step)
 
                     return custom_forward
 
@@ -725,6 +724,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    step=step
                 )
 
             hidden_states = outputs[0]
@@ -765,11 +765,6 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         super().__init__(config)
         self.transformer = GPTJRModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
-        if config.encoder_path is not None:
-            self.encoder = AutoModel.from_pretrained(config.encoder_path)
-            # freeze encoder and don't get gradiets
-            self.encoder.requires_grad_(False)
-
         # Model parallel
         self.model_parallel = False
 
@@ -832,23 +827,20 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_input_ids: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        decoder_head_mask: Optional[torch.FloatTensor] = None,
-        cross_attn_head_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -859,22 +851,9 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # Encode if needed (training, first prediction pass)
-        if encoder_outputs is None:
-            # Convert encoder inputs in embeddings if needed
-            encoder_outputs = self.encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                head_mask=head_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
         transformer_outputs = self.transformer(
             input_ids,
-            encoder_hidden_states=encoder_outputs,
+            encoder_hidden_states=encoder_hidden_states,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -885,6 +864,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )
 
         hidden_states = transformer_outputs[0]
diff --git a/gpt4all/train/metrics.py b/gpt4all/train/metrics.py
new file mode 100644
index 00000000..dac602d8
--- /dev/null
+++ b/gpt4all/train/metrics.py
@@ -0,0 +1,50 @@
+from collections import Counter
+import string
+
+import re
+
+
+# adapted from huggingface
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(predictions, ground_truths):
+    total_f1 = []
+    for prediction, ground_truth in zip(predictions, ground_truths):
+        prediction_tokens = normalize_answer(prediction).split()
+        ground_truth_tokens = normalize_answer(ground_truth).split()
+        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+        num_same = sum(common.values())
+        if num_same == 0:
+            total_f1.append(0)
+            continue
+
+        precision = 1.0 * num_same / len(prediction_tokens)
+        recall = 1.0 * num_same / len(ground_truth_tokens)
+        f1 = (2 * precision * recall) / (precision + recall)
+        total_f1.append(f1)
+
+    return total_f1
+
+
+def exact_match_score(predictions, ground_truths):
+    exact_scores = []
+    for prediction, ground_truth in zip(predictions, ground_truths):
+        exact_scores.append(normalize_answer(prediction) == normalize_answer(ground_truth))
+    return exact_scores
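
The reworked alpha gate in GPTJRBlock no longer tracks its own training step; the caller threads `step` through GPTJRForCausalLM -> GPTJRModel -> each block, and `_update_alpha` now simply returns the blend weight for that step. A small standalone sketch of the resulting schedule (illustration only, not part of the patch; it assumes a single process, i.e. world_size = 1):

import torch

# Mirror of GPTJRBlock._update_alpha from the diff above, with world_size
# fixed to 1 for illustration.
def update_alpha(step, world_size=1):
    return torch.clamp(
        torch.tensor([1 / (max(step * world_size, 1)) ** 0.08]),
        min=torch.tensor([0.5]),
        max=torch.tensor([1.0]),
    )

for step in (1, 10, 100, 1000, 10000):
    # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
    print(step, round(update_alpha(step).item(), 3))  # roughly 1.0, 0.83, 0.69, 0.61, 0.5

Early in training the block therefore leans almost entirely on the self-attention residual and only gradually mixes in the cross-attention path, bottoming out at an even 50/50 blend, which is also the fixed value used when `step` is None (e.g. at evaluation or inference).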
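With encoder_path dropped from GPTJRConfig and the frozen AutoModel encoder removed from GPTJRForCausalLM, retrieval encodings are now expected to be computed outside the model and passed in as `encoder_hidden_states` of shape (bs, knn, encoder_dim), as noted in the block comment. A rough sketch of how a caller might build that tensor; the encoder checkpoint, the mean pooling, and the variable names are assumptions for illustration, and the encoder's output dimension must match the encoder_dim set in GPTJRConfig:

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical retrieval encoder (384-dim MiniLM); any encoder whose output
# dimension equals config.encoder_dim would do.
encoder_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
encoder = AutoModel.from_pretrained(encoder_name)

neighbors = ["first retrieved passage", "second retrieved passage"]  # knn = 2
batch = tokenizer(neighbors, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    token_states = encoder(**batch).last_hidden_state  # (knn, tokens, encoder_dim)

# Crude mean pooling over tokens (ignores padding) purely for illustration.
pooled = token_states.mean(dim=1)            # (knn, encoder_dim)
encoder_hidden_states = pooled.unsqueeze(0)  # (1, knn, encoder_dim)

# outputs = model(input_ids, encoder_hidden_states=encoder_hidden_states, step=global_step)

The commented-out call at the end assumes a GPTJRForCausalLM instance named `model` and a training-loop counter `global_step`; both are placeholders, not identifiers defined by the patch.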
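The new gpt4all/train/metrics.py helpers operate on lists of strings and return per-example scores; answers are lowercased and stripped of punctuation, articles, and extra whitespace before comparison. A quick sanity check (the example strings are made up):

from gpt4all.train.metrics import exact_match_score, f1_score

predictions = ["The cat sat on the mat.", "Paris"]
ground_truths = ["the cat sat on a mat", "Paris, France"]

print(f1_score(predictions, ground_truths))           # [1.0, 0.666...]: token-overlap F1 per pair
print(exact_match_score(predictions, ground_truths))  # [True, False]: exact match after normalization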