Repository: https://github.com/nomic-ai/gpt4all.git

commit c9dd9152c3 (parent 48e07be9e9)

    feat: model def + metrics
Model definition:

@@ -139,7 +139,6 @@ class GPTJRConfig(PretrainedConfig):
         self.eos_token_id = eos_token_id

         self.encoder_dim = encoder_dim
-        self.encoder_path = encoder_path

         super().__init__(
             bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
@@ -423,8 +423,6 @@ class GPTJRBlock(nn.Module):
         self.cross_attn = GPTJRCrossAttention(config)
         self.cross_attn_mlp = GPTJRMLP(inner_dim, config)

-        self.alpha = nn.Parameter(torch.ones(1), requires_grad=False).to(self.ln_1.weight.dtype)
-        self.step = 1
         self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1

     def forward(
@@ -436,6 +434,7 @@ class GPTJRBlock(nn.Module):
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
+        step: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
         # shape (bs, seq_len, hidden_dim)
         residual = hidden_states
@@ -455,7 +454,8 @@ class GPTJRBlock(nn.Module):
         self_attention_residual = attn_output + feed_forward_hidden_states + residual

         # encoder_hidden_states -> (bs, knn, encoder_dim)
-        # may not need, can norm encodings
+        if encoder_hidden_states.dtype != hidden_states.dtype:
+            encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)
         encoder_normed = self.ln_2(encoder_hidden_states)

         # cross_attn_outputs -> (bs, seq_len, dim)
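The new dtype guard matters when the retrieval encodings come from a separate encoder kept in a different precision than the decoder. A minimal illustration with made-up shapes (not taken from the commit):

import torch

# Hypothetical case: decoder states in fp16, retrieval encodings left in fp32
# by an external encoder; shapes are (bs, seq_len, hidden_dim) and (bs, knn, encoder_dim).
hidden_states = torch.randn(2, 16, 4096, dtype=torch.float16)
encoder_hidden_states = torch.randn(2, 8, 4096, dtype=torch.float32)

if encoder_hidden_states.dtype != hidden_states.dtype:
    # mirrors the cast added in this hunk, so ln_2 and cross-attention run in one dtype
    encoder_hidden_states = encoder_hidden_states.to(hidden_states.dtype)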
@@ -473,24 +473,22 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )

-        alpha = self.alpha if self.training else 0.5
-        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        if step is not None:
+            alpha = self._update_alpha(step)
+            alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
+        else:
+            alpha = 0.5

+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
             outputs = (hidden_states,) + outputs[1:]

-        # if training update alpha
-        if self.training:
-            self.step += 1
-            self._update_alpha(self.step)
-
-
         return outputs # hidden_states, present, (attentions)

     def _update_alpha(self, iteration):
-        self.alpha.data = torch.clamp(torch.tensor([1 / (iteration * self.world_size) ** 0.05]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
+        return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))


 class GPTJRPreTrainedModel(PreTrainedModel):
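With self.alpha and self.step gone, the gate between the cross-attention branch and the self-attention residual is now derived purely from the step argument. A standalone sketch of the new schedule (the helper name and world_size value below are illustrative, the formula is the one from _update_alpha):

import torch

def update_alpha(iteration, world_size):
    # same formula as the new GPTJRBlock._update_alpha
    return torch.clamp(
        torch.tensor([1 / (max(iteration * world_size, 1)) ** 0.08]),
        min=torch.tensor([0.5]),
        max=torch.tensor([1.0]),
    )

for step in (1, 100, 1_000, 10_000):
    print(step, round(update_alpha(step, world_size=8).item(), 3))

# alpha weights the self-attention residual; (1 - alpha) weights the cross-attention branch.
# The value decays monotonically and reaches the 0.5 floor once step * world_size is roughly
# 5.8k, after which the two branches are blended equally. With step=None (inference), alpha is 0.5.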
@@ -597,6 +595,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None,
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -704,7 +703,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
+                        return module(*inputs, use_cache, output_attentions, step)

                     return custom_forward

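torch.utils.checkpoint only replays the tensor arguments it is handed, so non-tensor extras such as use_cache, output_attentions, and now step reach the block through the closure instead. A self-contained sketch of that pattern with a toy module (the toy block is an assumption, not GPTJRBlock):

import torch
from torch.utils.checkpoint import checkpoint

linear = torch.nn.Linear(4, 4)

def block(x, use_cache, output_attentions, step):
    # stand-in for a block that takes its flags and step as trailing positional args
    scale = 1.0 if step is None else 1.0 / (1 + step)
    return linear(x) * scale

use_cache, output_attentions, step = False, False, 100

def create_custom_forward(module):
    def custom_forward(*inputs):
        # captured from the enclosing scope, mirroring the change in this hunk
        return module(*inputs, use_cache, output_attentions, step)
    return custom_forward

x = torch.randn(2, 4, requires_grad=True)
out = checkpoint(create_custom_forward(block), x)
out.sum().backward()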
@@ -725,6 +724,7 @@ class GPTJRModel(GPTJRPreTrainedModel):
                     head_mask=head_mask[i],
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    step=step
                 )

            hidden_states = outputs[0]
@@ -765,11 +765,6 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         super().__init__(config)
         self.transformer = GPTJRModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
-        if config.encoder_path is not None:
-            self.encoder = AutoModel.from_pretrained(config.encoder_path)
-            # freeze encoder and don't get gradiets
-            self.encoder.requires_grad_(False)
-

         # Model parallel
         self.model_parallel = False
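Together with the encoder_path removal from GPTJRConfig above, this drops the frozen encoder from the decoder entirely: retrieval encodings are now expected to be produced outside the model and passed to forward as encoder_hidden_states. A hedged sketch of how a caller might build them (the checkpoint name and the mean pooling are assumptions, not part of the commit):

import torch
from transformers import AutoModel, AutoTokenizer

# Any encoder whose hidden size matches config.encoder_dim would do; this one is hypothetical.
encoder_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(encoder_name)
encoder = AutoModel.from_pretrained(encoder_name)
encoder.requires_grad_(False)  # keep it frozen, as the deleted __init__ code did

neighbors = ["retrieved passage one", "retrieved passage two"]  # knn = 2
enc = tokenizer(neighbors, padding=True, return_tensors="pt")
with torch.no_grad():
    token_states = encoder(**enc).last_hidden_state        # (knn, tokens, encoder_dim)
    encoder_hidden_states = token_states.mean(dim=1)[None]  # (bs=1, knn, encoder_dim)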
@@ -832,23 +827,20 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):

     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
-        decoder_input_ids: Optional[torch.FloatTensor] = None,
-        decoder_attention_mask: Optional[torch.FloatTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        decoder_head_mask: Optional[torch.FloatTensor] = None,
-        cross_attn_head_mask: Optional[torch.FloatTensor] = None,
-        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -859,22 +851,9 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        # Encode if needed (training, first prediction pass)
-        if encoder_outputs is None:
-            # Convert encoder inputs in embeddings if needed
-            encoder_outputs = self.encoder(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                head_mask=head_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-
         transformer_outputs = self.transformer(
             input_ids,
-            encoder_hidden_states=encoder_outputs,
+            encoder_hidden_states=encoder_hidden_states,
             past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -885,6 +864,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )
         hidden_states = transformer_outputs[0]

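The net effect on the causal-LM interface: callers pass decoder input_ids plus precomputed encoder_hidden_states, and optionally the global training step that drives the per-block alpha schedule. A usage sketch with placeholder shapes (model is assumed to be an initialized GPTJRForCausalLM; the labels-based loss follows the upstream GPT-J head):

import torch

batch, seq_len, knn, encoder_dim = 2, 16, 8, 384   # placeholder sizes
input_ids = torch.randint(0, 50_000, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len)
encoder_hidden_states = torch.randn(batch, knn, encoder_dim)

outputs = model(
    input_ids=input_ids,
    encoder_hidden_states=encoder_hidden_states,
    attention_mask=attention_mask,
    labels=input_ids,   # causal-LM style labels
    step=1_000,         # training step for the alpha schedule; omit at inference
)
loss = outputs.loss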
gpt4all/train/metrics.py (new file, 50 lines)

@@ -0,0 +1,50 @@
+from collections import Counter
+import string
+import re
+
+
+# adapted from huggingface
+def normalize_answer(s):
+    """Lower text and remove punctuation, articles and extra whitespace."""
+
+    def remove_articles(text):
+        return re.sub(r"\b(a|an|the)\b", " ", text)
+
+    def white_space_fix(text):
+        return " ".join(text.split())
+
+    def remove_punc(text):
+        exclude = set(string.punctuation)
+        return "".join(ch for ch in text if ch not in exclude)
+
+    def lower(text):
+        return text.lower()
+
+    return white_space_fix(remove_articles(remove_punc(lower(s))))
+
+
+def f1_score(predictions, ground_truths):
+    total_f1 = []
+    for prediction, ground_truth in zip(predictions, ground_truths):
+        prediction_tokens = normalize_answer(prediction).split()
+        ground_truth_tokens = normalize_answer(ground_truth).split()
+        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+        num_same = sum(common.values())
+        if num_same == 0:
+            total_f1.append(0)
+            continue
+
+        precision = 1.0 * num_same / len(prediction_tokens)
+        recall = 1.0 * num_same / len(ground_truth_tokens)
+        f1 = (2 * precision * recall) / (precision + recall)
+        total_f1.append(f1)
+
+    return total_f1
+
+
+def exact_match_score(predictions, ground_truths):
+    exact_scores = []
+    for prediction, ground_truth in zip(predictions, ground_truths):
+        exact_scores.append(normalize_answer(prediction) == normalize_answer(ground_truth))
+
+    return exact_scores
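Both helpers return per-example lists (token-overlap F1 and boolean exact match) rather than aggregates, so averaging is left to the caller. A brief usage sketch, assuming the package layout makes gpt4all.train.metrics importable:

from gpt4all.train.metrics import exact_match_score, f1_score

predictions = ["The Eiffel Tower is in Paris.", "forty-two"]
ground_truths = ["Eiffel Tower, Paris", "forty two"]

f1s = f1_score(predictions, ground_truths)           # per-example F1 over normalized tokens
ems = exact_match_score(predictions, ground_truths)  # per-example booleans

print(sum(f1s) / len(f1s))   # mean F1
print(sum(ems) / len(ems))   # exact-match rate (True counts as 1)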