feat: model def + metrics

Zach Nussbaum 2023-05-01 21:38:36 +00:00
parent 48e07be9e9
commit c9dd9152c3
3 changed files with 68 additions and 39 deletions

View File

@@ -139,7 +139,6 @@ class GPTJRConfig(PretrainedConfig):
self.eos_token_id = eos_token_id
self.encoder_dim = encoder_dim
self.encoder_path = encoder_path
super().__init__(
bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
)
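(Aside, not part of the diff: a minimal sketch of how the retrieval-related config fields above might be set. The import path and the concrete encoder_dim / encoder_path values are assumptions, not taken from this commit.)

from gpt4all.models import GPTJRConfig  # import path is an assumption

config = GPTJRConfig(
    encoder_dim=384,  # hidden size of the retrieval encoder (assumed value)
    encoder_path="sentence-transformers/all-MiniLM-L6-v2",  # hypothetical encoder checkpoint
)
print(config.encoder_dim, config.encoder_path)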

View File

@@ -423,8 +423,6 @@ class GPTJRBlock(nn.Module):
self.cross_attn = GPTJRCrossAttention(config)
self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype)
self.step = 1
self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1
def forward(
@@ -436,6 +434,7 @@ class GPTJRBlock(nn.Module):
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
step: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
# shape (bs, seq_len, hidden_dim)
residual = hidden_states
@@ -455,7 +454,8 @@ class GPTJRBlock(nn.Module):
self_attention_residual = attn_output + feed_forward_hidden_states + residual
# encoder_hidden_states -> (bs, knn, encoder_dim)
# may not need, can norm encodings
if encoder_hidden_states.dtype != hidden_states.dtype:
encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
encoder_normed = self.ln_2(encoder_hidden_states)
# cross_attn_outputs -> (bs, seq_len, dim)
@@ -473,24 +473,22 @@ class GPTJRBlock(nn.Module):
cross_attn_output[0]
)
alpha = self.alpha if self.training else 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if step is not None:
alpha = self._update_alpha(step)
alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
else:
alpha = 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
# if training update alpha
if self.training:
self.step += 1
self._update_alpha(self.step)
return outputs # hidden_states, present, (attentions)
def _update_alpha(self, iteration):
self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
class GPTJRPreTrainedModel(PreTrainedModel):
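(Aside, not part of the diff: a self-contained sketch of the schedule the new step-driven _update_alpha implements. Since the block now mixes hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual, a gate that decays from near 1.0 to the 0.5 floor phases the cross-attention branch in gradually over training; world_size=8 below is an assumed value.)

import torch

def alpha_schedule(step, world_size=8):
    # mirrors the new GPTJRBlock._update_alpha: 1 / (step * world_size) ** 0.08,
    # clamped to [0.5, 1.0], so the gate decays toward 0.5 as training proceeds
    return torch.clamp(
        torch.tensor([1 / (max(step * world_size, 1)) ** 0.08]),
        min=torch.tensor([0.5]),
        max=torch.tensor([1.0]),
    )

for step in (1, 10, 100, 1000, 10000):
    print(step, alpha_schedule(step).item())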
@@ -597,6 +595,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -704,7 +703,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return module(*inputs, use_cache, output_attentions, step)
return custom_forward
@@ -725,6 +724,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
step=step
)
hidden_states = outputs[0]
@@ -765,11 +765,6 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
super().__init__(config)
self.transformer = GPTJRModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
if config.encoder_path is not None:
self.encoder = AutoModel.from_pretrained(config.encoder_path)
# freeze encoder and don't compute gradients
self.encoder.requires_grad_(False)
# Model parallel
self.model_parallel = False
@@ -832,23 +827,20 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
decoder_input_ids: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
decoder_head_mask: Optional[torch.FloatTensor] = None,
cross_attn_head_mask: Optional[torch.FloatTensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -859,22 +851,9 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
transformer_outputs = self.transformer(
input_ids,
encoder_hidden_states=encoder_outputs,
encoder_hidden_states=encoder_hidden_states,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -885,6 +864,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
step=step,
)
hidden_states = transformer_outputs[0]
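(Aside, not part of the diff: with the in-model encoder call removed, the training loop now has to produce encoder_hidden_states itself and thread step through the forward pass. A minimal sketch under assumed names and shapes; lm, encoder, neighbor_ids, and the mean-pooling into (bs, knn, encoder_dim) are placeholders, not taken from this commit.)

import torch

def retrieval_lm_step(lm, encoder, input_ids, neighbor_ids, step):
    # `lm` is a GPTJRForCausalLM and `encoder` a frozen retrieval encoder,
    # both constructed elsewhere; pooling and shapes are assumptions.
    with torch.no_grad():  # keep the encoder frozen, as the removed requires_grad_(False) did
        enc_out = encoder(input_ids=neighbor_ids).last_hidden_state  # (bs * knn, neighbor_len, encoder_dim)
        pooled = enc_out.mean(dim=1)  # mean-pool neighbor tokens (assumed pooling)
        encoder_hidden_states = pooled.view(input_ids.shape[0], -1, pooled.shape[-1])  # (bs, knn, encoder_dim)

    outputs = lm(
        input_ids,
        encoder_hidden_states=encoder_hidden_states,
        labels=input_ids,
        step=step,  # drives the alpha gate inside each GPTJRBlock while training
    )
    return outputs.loss  # assumes return_dict=True so a CausalLMOutputWithPast is returned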

gpt4all/train/metrics.py Normal file (50 additions)
View File

@@ -0,0 +1,50 @@
from collections import Counter
import string
import re


# adapted from Hugging Face (SQuAD-style evaluation)
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(predictions, ground_truths):
    """Token-overlap F1 for each (prediction, ground truth) pair."""
    total_f1 = []
    for prediction, ground_truth in zip(predictions, ground_truths):
        prediction_tokens = normalize_answer(prediction).split()
        ground_truth_tokens = normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            total_f1.append(0.0)
            continue
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        total_f1.append(f1)
    return total_f1


def exact_match_score(predictions, ground_truths):
    """Exact string match after normalization, one boolean per pair."""
    exact_scores = []
    for prediction, ground_truth in zip(predictions, ground_truths):
        exact_scores.append(normalize_answer(prediction) == normalize_answer(ground_truth))
    return exact_scores