feat: forward for pythiaseek

Zach Nussbaum 2023-05-04 03:16:33 +00:00
parent d1b64d7eed
commit 71db2b664a
3 changed files with 278 additions and 61 deletions

View File

@ -129,4 +129,9 @@ class PythiaSeekConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual
self.encoder_dim = encoder_dim
self.total_alpha_steps = total_alpha_steps
self.initial_alpha = initial_alpha
self.final_alpha = final_alpha
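For reference, a minimal sketch of how these new config fields might be set when constructing the config; only the field names come from this diff, and the values are illustrative (encoder_dim=384 and hidden_size=1024 mirror the test script added below, the alpha-schedule values are assumptions):

from gpt4all.models import PythiaSeekConfig

# Illustrative values; total_alpha_steps / initial_alpha / final_alpha here are assumptions, not defaults from this commit.
config = PythiaSeekConfig(
    encoder_dim=384,           # width of the retrieval encoder embeddings
    hidden_size=1024,          # decoder hidden size
    total_alpha_steps=10_000,  # steps over which the cross-attention mixing weight decays
    initial_alpha=1.0,         # start by weighting the self-attention residual fully
    final_alpha=0.5,           # end with an even mix of cross- and self-attention
)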

View File

@ -233,6 +233,131 @@ class PythiaSeekAttention(nn.Module):
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
class PythiaSeekCrossAttention(PythiaSeekAttention):
def __init__(self, config):
super().__init__(config)
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))
self.embed_dim = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_attention_heads
if self.head_dim * self.num_attention_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
f" `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 1, 3, 2)
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into hidden dim
"""
# (bs, num_attention_heads, seq_len, head_dim) -> (bs, seq_len, num_attention_heads, head_dim)
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def _attn(
self,
query,
key,
value,
attention_mask=None,
head_mask=None,
):
# query -> (bs, num_attention_heads, seq_len, head_dim)
# key -> (bs, num_attention_heads, head_dim, neighbors)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
attn_weights = torch.matmul(query, key)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.to(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
# value -> (bs, num_attention_heads, neighbors, head_dim)
# attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
# attn_output -> (bs, num_attention_heads, seq_len, head_dim)
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> Union[
Tuple[torch.Tensor, Tuple[torch.Tensor]],
Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
]:
query = self.q_proj(hidden_states)
# cross-attention: keys and values are projected from the retrieved encoder hidden states
key = self.k_proj(encoder_hidden_states)
value = self.v_proj(encoder_hidden_states)
# (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
query = self._split_heads(query, self.num_attention_heads, self.head_dim)
# (bs, knn, dim) -> (bs, knn, head_dim, num_attention_heads)
key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
# value -> (bs, num_attention_heads, knn, head_dim); key -> (bs, num_attention_heads, head_dim, knn)
value = value.permute(0, 3, 1, 2)
key = key.permute(0, 3, 2, 1)
if layer_past is not None:
past_key = layer_past[0]
past_value = layer_past[1]
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = (key, value)
else:
present = None
# compute cross-attention: Softmax(QK^T) x V
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
def attention_mask_func(attention_scores, ltor_mask):
attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
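A rough sketch of the shape bookkeeping PythiaSeekCrossAttention performs, written with plain tensors rather than the module itself; the dimensions (bs, seq_len, knn, heads) are illustrative, and the shape comments follow the permutes in forward and _attn above:

import torch

bs, seq_len, knn, num_heads, head_dim = 2, 16, 3, 4, 256  # hidden_size = 4 * 256 = 1024

# query comes from the decoder hidden states, key/value from the retrieved neighbors
query = torch.randn(bs, num_heads, seq_len, head_dim)  # after q_proj + _split_heads
key = torch.randn(bs, num_heads, head_dim, knn)        # after k_proj, _split_knn_attn_heads, permute(0, 3, 2, 1)
value = torch.randn(bs, num_heads, knn, head_dim)      # after v_proj, _split_knn_attn_heads, permute(0, 3, 1, 2)

attn_weights = torch.softmax(query @ key, dim=-1)      # (bs, num_heads, seq_len, knn)
attn_output = attn_weights @ value                     # (bs, num_heads, seq_len, head_dim)

# _merge_heads: back to (bs, seq_len, hidden_size) before out_proj
merged = attn_output.permute(0, 2, 1, 3).reshape(bs, seq_len, num_heads * head_dim)
print(merged.shape)  # torch.Size([2, 16, 1024])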
@ -308,18 +433,29 @@ class PythiaSeekLayer(nn.Module):
self.attention = PythiaSeekAttention(config)
self.mlp = PythiaSeekMLP(config)
self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.cross_attn = PythiaSeekCrossAttention(config)
self.cross_attn_mlp = PythiaSeekMLP(config)
self.total_alpha_steps = config.total_alpha_steps
self.initial_alpha = config.initial_alpha
self.final_alpha = config.final_alpha
def forward(
self,
hidden_states: Optional[torch.FloatTensor],
encoder_hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
step: Optional[int] = None,
):
ln_hidden_states = self.input_layernorm(hidden_states)
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
ln_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
@ -334,74 +470,75 @@ class PythiaSeekLayer(nn.Module):
# pseudocode:
# x = x + attn(ln1(x)) + mlp(ln2(x))
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = mlp_output + attn_output + hidden_states
self_attention_residual = mlp_output + attn_output + hidden_states
else:
# pseudocode:
# x = x + attn(ln1(x))
# x = x + mlp(ln2(x))
attn_output = attn_output + hidden_states
mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
hidden_states = mlp_output + attn_output
self_attention_residual = mlp_output + attn_output
# encoder_hidden_states -> (bs, knn, encoder_dim)
if encoder_hidden_states.dtype != ln_hidden_states.dtype:
encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
encoder_normed = self.cross_attn_ln(encoder_hidden_states)
# cross_attn_outputs -> (bs, seq_len, dim)
cross_attn_output = self.cross_attn(
ln_hidden_states,
encoder_hidden_states=encoder_normed,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
cross_attn_ff = self.cross_attn_mlp(
cross_attn_output[0]
)
if step is not None:
alpha = self._update_alpha(step)
else:
alpha = 0.5
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
if use_cache:
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights)
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights)
outputs = (hidden_states,) + outputs[1:]
return outputs
return outputs # hidden_states, present, (attentions)
def _update_alpha(self, current_step):
"""
Computes the cross-attention mixing weight `alpha` for the current step using a cosine decay schedule.
Args:
current_step (int): The current training step.
Returns:
float: The alpha used to mix the cross-attention and self-attention residual streams at this step.
"""
# NOTE: hard-coded here; self.initial_alpha / self.final_alpha from the config are not used
initial_alpha = 1
final_alpha = .5
if current_step >= self.total_alpha_steps:
return final_alpha
# Compute the cosine decay factor
cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))
# Interpolate between the initial and final alpha
alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay
return alpha
GPT_NEOX_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
behavior.
Parameters:
config ([`~PythiaSeekConfig`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
GPT_NEOX_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
class PythiaSeekModel(PythiaSeekPreTrainedModel):
@ -427,6 +564,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
@ -436,6 +574,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@ -528,26 +667,29 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
def create_custom_forward(module):
def custom_forward(*inputs):
# None for layer_past
return module(*inputs, use_cache, None, output_attentions)
return module(*inputs, use_cache, None, output_attentions, step)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
encoder_hidden_states,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = layer(
hidden_states,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
output_attentions=output_attentions,
step=step,
)
hidden_states = outputs[0]
if use_cache is True:
@ -577,9 +719,16 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.gpt_neox = PythiaSeekModel(config)
self.pythiaseek = PythiaSeekModel(config)
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.hidden_size = config.hidden_size
self.encoder_dim = config.encoder_dim
if self.hidden_size != self.encoder_dim:
self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
nn.Linear(config.hidden_size * 4, config.hidden_size))
# Initialize weights and apply final processing
self.post_init()
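When the retrieval encoder width differs from the decoder hidden size, the enc_dec_proj defined above maps neighbor embeddings into the decoder's width before they reach the cross-attention blocks. A minimal shape sketch with the same two-Linear structure (the 384/1024 dimensions mirror the test script; they are not requirements):

import torch
import torch.nn as nn

encoder_dim, hidden_size = 384, 1024
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, hidden_size * 4),
    nn.Linear(hidden_size * 4, hidden_size),
)

bs, knn = 1, 2
encoder_hidden_states = torch.randn(bs, knn, encoder_dim)  # (bs, neighbors, encoder_dim)
projected = enc_dec_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([1, 2, 1024]) -> ready for the cross-attention layers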
@ -591,17 +740,19 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
input_ids: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
step: Optional[int] = None
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@ -644,8 +795,13 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.gpt_neox(
if self.hidden_size != self.encoder_dim:
encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
outputs = self.pythiaseek(
input_ids,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
@ -655,6 +811,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
step=step,
)
hidden_states = outputs[0]

View File

@ -0,0 +1,55 @@
import torch
from gpt4all.models import PythiaSeekConfig, PythiaSeekForCausalLM
from transformers import AutoTokenizer, AutoModel
# seed torch
torch.manual_seed(0)
config = PythiaSeekConfig(encoder_dim=384,
hidden_size=1024,
intermediate_size=4096,
num_hidden_layers=4,
num_attention_heads=4)
print("loaded config")
print("loading model")
model = PythiaSeekForCausalLM(config)
print("loaded model")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024
encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
def mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
text = "The quick brown fox jumps over the lazy dog."
print("Encoded knn")
tokenized = encoder_tokenizer(text, return_tensors="pt")
# bs, seq_len, dim
encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
inputs = "What did the fox do?"
print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
print("Running model")
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
print(outputs)
print(outputs[0].shape)
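The script above leaves step unset, so every layer falls back to alpha = 0.5; during training the current optimizer step would presumably be passed through so the cosine schedule takes effect. A hedged sketch of that call, reusing the objects already defined in this script:

# Assumes `model`, `tokenized_input`, and `encoder_hidden_states` from the script above.
for step in (0, 100, 1000):
    outputs = model(**tokenized_input,
                    encoder_hidden_states=encoder_hidden_states,
                    step=step)
    # logits keep the same (batch, seq_len, vocab_size) shape; only the cross-/self-attention mix changes with step
    print(step, outputs[0].shape)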