Merge branch 'junior' of https://github.com/nomic-ai/gpt4all into junior

This commit is contained in:
Zach Nussbaum 2023-05-05 02:18:10 +00:00
commit 4d5e48a7a9
13 changed files with 1156 additions and 56 deletions

View File

@ -89,7 +89,7 @@ Please see [GPT4All-J Technical Report](https://static.nomic.ai/gpt4all/2023_GPT
We have released updated versions of our `GPT4All-J` model and training data.
- `v1.0`: The original model trained on the v1.0 dataset
- `v1.1-breezy`: Trained on afiltered dataset where we removed all instances of AI language model
- `v1.1-breezy`: Trained on a filtered dataset where we removed all instances of AI language model
- `v1.2-jazzy`: Trained on a filtered dataset where we also removed instances like I'm sorry, I can't answer... and AI language model
The [models](https://huggingface.co/nomic-ai/gpt4all-j) and [data](https://huggingface.co/datasets/nomic-ai/gpt4all-j-prompt-generations) versions can be specified by passing a `revision` argument.

View File

@ -35,14 +35,5 @@
],
"eps": 1e-08
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": 0,
"warmup_max_lr": "auto",
"warmup_num_steps": "auto",
"warmup_type": "linear"
}
}
}

View File

@ -1,9 +1,10 @@
# model/tokenizer
model_name: "EleutherAI/gpt-j-6B"
tokenizer_name: "EleutherAI/gpt-j-6B"
model_name: "EleutherAI/gpt-j-6b"
tokenizer_name: "EleutherAI/gpt-j-6b"
version: null
gradient_checkpointing: true
save_name: "nomic-ai/gpt-jr-decay-alpha"
push_to_hub: false
encoder_dim: 384
# dataset
@ -11,7 +12,7 @@ streaming: false
num_proc: 64
dataset_path: "/home/paperspace/gpt4all/gpt4all/index/squad_supplemented_train"
max_length: 1024
batch_size: 32
batch_size: 8
pct_test: 0.05
q_column: "question"
a_column: "answers"
@ -23,7 +24,7 @@ lr: 1.0e-4
min_lr: 0
weight_decay: 0.0
eval_every: 50
save_every: 500
save_every: -1
log_grads_every: 100
log_lr_every: 10
output_dir: "ckpts/decay_alpha"
@ -31,6 +32,8 @@ checkpoint: null
lora: false
warmup_steps: 500
num_epochs: 5
debug: false
scheduler: false
# logging
wandb: true

View File

@ -1,8 +1,12 @@
from .configuration_gpt_jr import GPTJRConfig
from .modeling_gpt_jr import GPTJRForCausalLM
from .gpt_jr.configuration_gpt_jr import GPTJRConfig
from .gpt_jr.modeling_gpt_jr import GPTJRForCausalLM
from .pythiaseek import PythiaSeekForCausalLM, PythiaSeekConfig
__all__ = [
"GPTJRConfig",
"GPTJRForCausalLM"
"GPTJRForCausalLM",
"PythiaSeekConfig",
"PythiaSeekForCausalLM",
]

View File

View File

@ -31,7 +31,7 @@ from transformers.modeling_outputs import (
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from gpt4all.models.configuration_gpt_jr import GPTJRConfig
from gpt4all.models.gpt_jr.configuration_gpt_jr import GPTJRConfig
logger = logging.get_logger(__name__)
@ -476,14 +476,13 @@ class GPTJRBlock(nn.Module):
cross_attn_output[0]
)
# if step is not None:
# alpha = self._update_alpha(step)
# #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
# else:
# alpha = 0.5
if step is not None:
alpha = self._update_alpha(step)
else:
alpha = 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
# hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
hidden_states = cross_attn_ff + self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs
else:
@ -491,10 +490,6 @@ class GPTJRBlock(nn.Module):
return outputs # hidden_states, present, (attentions)
# def _update_alpha(self, iteration):
# return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
def _update_alpha(self, current_step):
"""
Computes the learning rate for the current step using a cosine decay schedule.
@ -801,7 +796,8 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
self.encoder_dim = config.encoder_dim
if self.hidden_size != self.encoder_dim:
self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.n_embd * 4),
nn.Linear(config.n_embd * 4, config.n_embd))
# Model parallel
self.model_parallel = False
@ -889,7 +885,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.hidden_size != self.encoder_dim:
encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
transformer_outputs = self.transformer(

View File

@ -35,7 +35,7 @@ encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_outputs = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
inputs = "What did the fox do?"
@ -43,7 +43,7 @@ print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_outputs=encoder_outputs)
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
print(outputs)
print(outputs[0].shape)

View File

@ -0,0 +1,2 @@
from .configuration_pythiaseek import PythiaSeekConfig
from .modeling_pythiaseek import PythiaSeekForCausalLM

View File

@ -0,0 +1,137 @@
# coding=utf-8
# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" GPTNeoX model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
GPT_NEOX_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"EleutherAI/gpt-neox-20b": "https://huggingface.co/EleutherAI/gpt-neox-20b/resolve/main/config.json",
# See all GPTNeoX models at https://huggingface.co/models?filter=gpt_neox
}
class PythiaSeekConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GPTNeoXModel`]. It is used to instantiate an
GPTNeoX model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of the GPTNeoX
[EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 50432):
Vocabulary size of the GPTNeoX model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`GPTNeoXModel`].
hidden_size (`int`, *optional*, defaults to 6144):
Dimension of the encoder layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 44):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 64):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 24576):
Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` are supported.
rotary_pct (`float`, *optional*, defaults to 0.25):
percentage of hidden dimensions to allocate to rotary embeddings
rotary_emb_base (`int`, *optional*, defaults to 10000)
base for computing rotary embeddings frequency
classifier_dropout (`float`, *optional*, defaults to 0.1):
Argument used when doing token classification, used in the model [`GPTNeoXForTokenClassification`].
The dropout ratio for the hidden layer.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
initializer_range (`float`, *optional*, defaults to 1e-5):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
The epsilon used by the layer normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
use_parallel_residual (`bool`, *optional*, defaults to `True`):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales (e.g. 20B).
Example:
```python
>>> from transformers import GPTNeoXConfig, GPTNeoXModel
>>> # Initializing a GPTNeoX gpt-neox-20b style configuration
>>> configuration = GPTNeoXConfig()
>>> # Initializing a model (with random weights) from the gpt-neox-20b style configuration
>>> model = GPTNeoXModel(configuration) # doctest: +SKIP
>>> # Accessing the model configuration
>>> configuration = model.config # doctest: +SKIP
```"""
model_type = "gpt_neox"
def __init__(
self,
vocab_size=50432,
hidden_size=6144,
num_hidden_layers=44,
num_attention_heads=64,
intermediate_size=24576,
hidden_act="gelu",
rotary_pct=0.25,
rotary_emb_base=10000,
classifier_dropout=0.1,
max_position_embeddings=2048,
initializer_range=0.02,
layer_norm_eps=1e-5,
use_cache=True,
bos_token_id=0,
eos_token_id=2,
tie_word_embeddings=False,
use_parallel_residual=True,
encoder_dim=4096,
total_alpha_steps=0,
initial_alpha=1,
final_alpha=.5,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.classifier_dropout = classifier_dropout
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual
self.encoder_dim = encoder_dim
self.total_alpha_steps = total_alpha_steps
self.initial_alpha = initial_alpha
self.final_alpha = final_alpha

View File

@ -0,0 +1,885 @@
# coding=utf-8
# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch PythiaSeek model."""
from typing import Optional, Tuple, Union
import math
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from gpt4all.models.pythiaseek import PythiaSeekConfig
logger = logging.get_logger(__name__)
GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST = [
"EleutherAI/gpt-neox-20b",
# See all PythiaSeek models at https://huggingface.co/models?filter=gpt_neox
]
class PythiaSeekPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
config_class = PythiaSeekConfig
base_model_prefix = "pythiaseek"
supports_gradient_checkpointing = True
_no_split_modules = ["PythiaSeekLayer"]
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, PythiaSeekModel):
module.gradient_checkpointing = value
class PythiaSeekAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.rotary_emb = RotaryEmbedding(
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base
)
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
persistent=False,
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
has_layer_past = layer_past is not None
# Compute QKV
# Attention heads [batch, seq_len, hidden_size]
# --> [batch, seq_len, (np * 3 * head_size)]
qkv = self.query_key_value(hidden_states)
# [batch, seq_len, (num_heads * 3 * head_size)]
# --> [batch, seq_len, num_heads, 3 * head_size]
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
qkv = qkv.view(*new_qkv_shape)
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
# Compute rotary embeddings on rotary_ndims
query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
if has_layer_past:
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)
# Cache QKV values
if has_layer_past:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
present = (key, value) if use_cache else None
# Compute attention
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
# Reshape outputs
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
attn_output = self.dense(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs
@classmethod
def _split_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Splits hidden dim into attn_head_size and num_attention_heads
"""
# tensor: [bs, seq_len, hidden_size]
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(new_shape)
# -> [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3)
return tensor
@classmethod
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor [bs, num_attention_heads, seq_len, attn_head_size]
tensor = tensor.permute(0, 2, 1, 3).contiguous()
# -> [bs, seq_len, num_attention_heads, attn_head_size]
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size)
# -> [bs, seq_len, hidden_size]
return tensor
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size]
# compute causal mask from causal mask buffer
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size)
attn_scores = torch.zeros(
batch_size * num_attention_heads,
query_length,
key_length,
dtype=query.dtype,
device=key.device,
)
attn_scores = torch.baddbmm(
attn_scores,
query,
key.transpose(1, 2),
beta=1.0,
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor),
)
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length)
mask_value = torch.finfo(attn_scores.dtype).min
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device)
attn_scores = torch.where(causal_mask, attn_scores, mask_value)
if attention_mask is not None:
# Apply the attention mask
attn_scores = attn_scores + attention_mask
attn_weights = nn.functional.softmax(attn_scores, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class PythiaSeekCrossAttention(PythiaSeekAttention):
def __init__(self, config):
super().__init__(config)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 1, 3, 2)
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# tensor -> (bs, seq_len, num_attention_heads, head_dim)
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# query -> (bs, num_attention_heads, seq_len, head_dim)
# key -> (bs, num_attention_heads, head_dim, neighbors)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
attn_weights = torch.matmul(query, key)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# value -> (bs, num_attention_heads, seq_len, head_dim)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
# attn_output -> (bs, num_attention_heads, seq_len, head_dim)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
# if we are doing cross attention
key = self.k_proj(encoder_hidden_states)
value = self.v_proj(encoder_hidden_states)
# (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
query = self._split_heads(query, self.num_attention_heads, self.head_dim)
# (bs, dim) -> (bs, num_attention_heads, head_dim)
key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
value = value.permute(0, 3, 1, 2)
key = key.permute(0, 3, 2, 1)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
return attention_scores
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class PythiaSeekMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size)
self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size)
self.act = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
hidden_states = self.dense_h_to_4h(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.dense_4h_to_h(hidden_states)
return hidden_states
class PythiaSeekLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.use_parallel_residual = config.use_parallel_residual
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.attention = PythiaSeekAttention(config)
self.mlp = PythiaSeekMLP(config)
self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cross_attn = PythiaSeekCrossAttention(config)
self.cross_attn_mlp = PythiaSeekMLP(config)
self.total_alpha_steps = config.total_alpha_steps
self.initial_alpha = config.initial_alpha
self.final_alpha = config.final_alpha
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
step: Optional[int] = None,
):
ln_hidden_states = self.input_layernorm(hidden_states)
attention_layer_outputs = self.attention(
ln_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights)
outputs = attention_layer_outputs[1:]
if self.use_parallel_residual:
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
self_attention_residual = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
self_attention_residual = mlp_output + attn_output
# encoder_hidden_states -> (bs, knn, encoder_dim)
if encoder_hidden_states.dtype != ln_hidden_states.dtype:
encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
encoder_normed = self.cross_attn_ln(encoder_hidden_states)
# cross_attn_outputs -> (bs, seq_len, dim)
cross_attn_output = self.cross_attn(
ln_hidden_states,
encoder_hidden_states=encoder_normed,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
cross_attn_ff = self.cross_attn_mlp(
cross_attn_output[0]
)
if step is not None:
alpha = self._update_alpha(step)
else:
alpha = 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs # hidden_states, present, (attentions)
def _update_alpha(self, current_step):
"""
Computes the learning rate for the current step using a cosine decay schedule.
Args:
initial_lr (float): The initial learning rate.
final_lr (float): The final learning rate.
total_steps (int): The total number of steps in the schedule.
current_step (int): The current step.
Returns:
float: The learning rate for the current step.
"""
initial_alpha = 1
final_alpha = .5
if current_step >= self.total_alpha_steps:
return final_alpha
# Compute the cosine decay factor
cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))
# Compute the current learning rate
alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay
return alpha
class PythiaSeekModel(PythiaSeekPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.config = config
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([PythiaSeekLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_in
def set_input_embeddings(self, value):
self.embed_in = value
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
use_cache = use_cache if use_cache is not None else self.config.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# Attention mask.
if attention_mask is not None:
assert batch_size > 0, "batch_size has to be defined and > 0"
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions, step)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
encoder_hidden_states,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = layer(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
step=step,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_attentions = all_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.final_layer_norm(hidden_states)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_attentions,
)
class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
def __init__(self, config):
super().__init__(config)
self.pythiaseek = PythiaSeekModel(config)
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.hidden_size = config.hidden_size
self.encoder_dim = config.encoder_dim
if self.hidden_size != self.encoder_dim:
self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
nn.Linear(config.hidden_size * 4, config.hidden_size))
# Initialize weights and apply final processing
self.post_init()
def get_output_embeddings(self):
return self.embed_out
def set_output_embeddings(self, new_embeddings):
self.embed_out = new_embeddings
def forward(
self,
input_ids: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are
only required when the model is used as a decoder in a Sequence to Sequence model.
Contains pre-computed hidden-states (key and values in the self-attention blocks that can be used (see
`past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, PythiaSeekForCausalLM, PythiaSeekConfig
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config = PythiaSeekConfig.from_pretrained("EleutherAI/gpt-neox-20b")
>>> config.is_decoder = True
>>> model = PythiaSeekForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.hidden_size != self.encoder_dim:
encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
outputs = self.pythiaseek(
input_ids,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
step=step,
)
hidden_states = outputs[0]
lm_logits = self.embed_out(hidden_states)
lm_loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
# we are doing next-token prediction; shift prediction scores and input ids by one
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithPast(
loss=lm_loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
input_shape = input_ids.shape
# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"position_ids": position_ids,
}
)
return model_inputs
def _reorder_cache(self, past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past

View File

@ -0,0 +1,55 @@
import torch
from gpt4all.models import PythiaSeekConfig, PythiaSeekForCausalLM
from transformers import AutoTokenizer, AutoModel
# seed torch
torch.manual_seed(0)
config = PythiaSeekConfig(encoder_dim=384,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=4,
num_attention_heads=4)
print("loaded config")
print("loading model")
model = PythiaSeekForCausalLM(config)
print("loaded model")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024
encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] #First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
text = "The quick brown fox jumps over the lazy dog."
print("Encoded knn")
tokenized = encoder_tokenizer(text, return_tensors="pt")
# bs, seq_len, dim
encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
inputs = "What did the fox do?"
print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
print(outputs)
print(outputs[0].shape)

View File

@ -75,13 +75,15 @@ def train(accelerator, config):
revision=config['version'] if 'version' in config else None,
use_cache=False if checkpoint else True,
encoder_dim=config["encoder_dim"],
total_alpha_steps=total_num_steps
# hardcoded!! TODO: change to config
total_alpha_steps=10000,
)
else:
model = AutoModelForCausalLM.from_pretrained(config["model_name"],
use_cache=False if checkpoint else True,
trust_remote_code=True)
accelerator.print(f"Training a {model.num_parameters():,} parameter model")
if checkpoint:
model.gradient_checkpointing_enable()
@ -106,6 +108,7 @@ def train(accelerator, config):
# Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
if config["scheduler"] or "scheduler" in accelerator.state.deepspeed_plugin.deepspeed_config:
if (
accelerator.state.deepspeed_plugin is None
or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
@ -120,13 +123,19 @@ def train(accelerator, config):
scheduler = DummyScheduler(
optimizer, total_num_steps=total_num_steps, warmup_num_steps=config["warmup_steps"]
)
model, optimizer, train_dataloader, val_dataloader, scheduler = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader, scheduler
)
scheduler = True
else:
model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
model, optimizer, train_dataloader, val_dataloader
)
scheduler = False
# setup for saving training states in case preemption
if scheduler:
accelerator.register_for_checkpointing(scheduler)
if config["checkpoint"]:
@ -154,6 +163,22 @@ def train(accelerator, config):
encoder_hidden_states=batch["encoder_hidden_states"],
step=step)
loss = outputs.loss
accelerator.print(f"Loss: {loss:.4f}")
if config["debug"]:
logits = outputs.logits
pred_tokens = torch.argmax(logits, dim=-1)
labels = batch["labels"].clone()
for row in range(labels.shape[0]):
curr_label = labels[row]
mask = curr_label != -100
decoded_true = tokenizer.decode(curr_label[mask], skip_special_tokens=True)
decoded_pred = tokenizer.decode(pred_tokens[row][mask], skip_special_tokens=True)
accelerator.print(f"Predicted tokens: {decoded_pred}")
accelerator.print(f"True tokens: {decoded_true}")
# gather loss before backprop in case of gradient accumulation
loss_values = accelerator.gather_for_metrics({"loss": loss.detach().float()})
@ -167,15 +192,17 @@ def train(accelerator, config):
if config["wandb"]:
if step > 0 and step % (config["log_lr_every"] ) == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.log({"lr": scheduler.get_last_lr()[0]}, step=curr_step)
lr = optimizer.param_groups[0]["lr"]
accelerator.log({"lr": lr}, step=curr_step)
if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
optimizer.step()
if scheduler:
scheduler.step()
optimizer.zero_grad()
if step > 0 and step % config["save_every"] == 0:
if step > 0 and config["save_every"] > 0 and step % config["save_every"] == 0:
curr_step = step + epoch * len(train_dataloader)
accelerator.save_state(f"{config['output_dir']}/step_{curr_step}")
@ -194,17 +221,17 @@ def train(accelerator, config):
curr_step = step + epoch * len(train_dataloader)
accelerator.log({**log_train, **log_val}, step=curr_step)
accelerator.print(f"Current LR: {scheduler.get_last_lr()[0]}")
accelerator.print(f"Current LR: {optimizer.param_groups[0]['lr']}")
accelerator.print(format_metrics(log_train, "train", f" step {step} "))
accelerator.print(format_metrics(log_val, "val", f" step {step} "))
train_loss.reset()
accelerator.print(f"Epoch {epoch} finished")
accelerator.print(f"Pushing to HF hub")
accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
if config["push_to_hub"]:
accelerator.print(f"Pushing to HF hub")
try:
if accelerator.is_main_process:
unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)