mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-06 03:56:45 +00:00)

feat: forward for pythiaseek

This commit is contained in:
parent d1b64d7eed
commit 71db2b664a
@@ -129,4 +129,9 @@ class PythiaSeekConfig(PretrainedConfig):
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.tie_word_embeddings = tie_word_embeddings
        self.use_parallel_residual = use_parallel_residual

        self.encoder_dim = encoder_dim
        self.total_alpha_steps = total_alpha_steps
        self.initial_alpha = initial_alpha
        self.final_alpha = final_alpha
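The new config fields stored above (encoder_dim, total_alpha_steps, initial_alpha, final_alpha) are plain constructor arguments. A minimal sketch of how they might be set; the values are illustrative and the defaults are not shown in this hunk:

from gpt4all.models import PythiaSeekConfig

config = PythiaSeekConfig(
    encoder_dim=384,          # width of the retrieval-encoder embeddings fed to cross-attention
    total_alpha_steps=1000,   # steps over which the self/cross-attention blend is annealed
    initial_alpha=1.0,
    final_alpha=0.5,
)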
@@ -233,6 +233,131 @@ class PythiaSeekAttention(nn.Module):
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights


class PythiaSeekCrossAttention(PythiaSeekAttention):
    def __init__(self, config):
        super().__init__(config)

        max_positions = config.max_position_embeddings
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
                1, 1, max_positions, max_positions
            ),
        )
        self.register_buffer("masked_bias", torch.tensor(-1e9))

        self.embed_dim = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_attention_heads
        if self.head_dim * self.num_attention_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                f" `num_attention_heads`: {self.num_attention_heads})."
            )
        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

    def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
        tensor = tensor.view(new_shape)

        return tensor.permute(0, 1, 3, 2)

    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden dim
        """
        # tensor -> (bs, seq_len, num_attention_heads, head_dim)
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
        return tensor.view(new_shape)

    def _attn(
        self,
        query,
        key,
        value,
        attention_mask=None,
        head_mask=None,
    ):
        # query -> (bs, num_attention_heads, seq_len, head_dim)
        # key -> (bs, num_attention_heads, head_dim, neighbors)
        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
        attn_weights = torch.matmul(query, key)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.to(value.dtype)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        # value -> (bs, num_attention_heads, seq_len, head_dim)
        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
        # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        encoder_hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[
        Tuple[torch.Tensor, Tuple[torch.Tensor]],
        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
    ]:
        query = self.q_proj(hidden_states)
        # if we are doing cross attention
        key = self.k_proj(encoder_hidden_states)
        value = self.v_proj(encoder_hidden_states)

        # (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
        query = self._split_heads(query, self.num_attention_heads, self.head_dim)
        # (bs, dim) -> (bs, num_attention_heads, head_dim)
        key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
        value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)

        value = value.permute(0, 3, 1, 2)
        key = key.permute(0, 3, 2, 1)

        if layer_past is not None:
            past_key = layer_past[0]
            past_value = layer_past[1]
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        # compute cross-attention: V x Softmax(QK^T)
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)


def attention_mask_func(attention_scores, ltor_mask):
    attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
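For reference, a standalone shape walk-through of the cross-attention math above. The sizes are made up for illustration, and the tensors stand in for the outputs of _split_heads and the key/value permutes in PythiaSeekCrossAttention.forward:

import torch

bs, seq_len, knn, heads, head_dim = 2, 16, 3, 4, 64

query = torch.randn(bs, heads, seq_len, head_dim)   # from _split_heads
key = torch.randn(bs, heads, head_dim, knn)         # after _split_knn_attn_heads + permute
value = torch.randn(bs, heads, knn, head_dim)       # after _split_knn_attn_heads + permute

attn_weights = torch.softmax(query @ key, dim=-1)   # (bs, heads, seq_len, knn)
attn_output = attn_weights @ value                  # (bs, heads, seq_len, head_dim)
assert attn_output.shape == (bs, heads, seq_len, head_dim)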
@@ -308,18 +433,29 @@ class PythiaSeekLayer(nn.Module):
        self.attention = PythiaSeekAttention(config)
        self.mlp = PythiaSeekMLP(config)

        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.cross_attn = PythiaSeekCrossAttention(config)
        self.cross_attn_mlp = PythiaSeekMLP(config)

        self.total_alpha_steps = config.total_alpha_steps
        self.initial_alpha = config.initial_alpha
        self.final_alpha = config.final_alpha

    def forward(
        self,
        hidden_states: Optional[torch.FloatTensor],
        encoder_hidden_states: Optional[torch.FloatTensor],
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        step: Optional[int] = None,
    ):
        ln_hidden_states = self.input_layernorm(hidden_states)
        attention_layer_outputs = self.attention(
-            self.input_layernorm(hidden_states),
+            ln_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            layer_past=layer_past,
@@ -334,74 +470,75 @@ class PythiaSeekLayer(nn.Module):
            # pseudocode:
            # x = x + attn(ln1(x)) + mlp(ln2(x))
            mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-            hidden_states = mlp_output + attn_output + hidden_states
+            self_attention_residual = mlp_output + attn_output + hidden_states
        else:
            # pseudocode:
            # x = x + attn(ln1(x))
            # x = x + mlp(ln2(x))
            attn_output = attn_output + hidden_states
            mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
-            hidden_states = mlp_output + attn_output
+            self_attention_residual = mlp_output + attn_output

        # encoder_hidden_states -> (bs, knn, encoder_dim)
        if encoder_hidden_states.dtype != ln_hidden_states.dtype:
            encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)

        encoder_normed = self.cross_attn_ln(encoder_hidden_states)

        # cross_attn_outputs -> (bs, seq_len, dim)
        cross_attn_output = self.cross_attn(
            ln_hidden_states,
            encoder_hidden_states=encoder_normed,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )

        cross_attn_ff = self.cross_attn_mlp(
            cross_attn_output[0]
        )

        if step is not None:
            alpha = self._update_alpha(step)
        else:
            alpha = 0.5

        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

        if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (hidden_states,) + outputs
        else:
-            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)
+            outputs = (hidden_states,) + outputs[1:]

-        return outputs
+        return outputs  # hidden_states, present, (attentions)

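The (1 - alpha) * cross_attn_ff + alpha * self_attention_residual blend above is driven by _update_alpha, added just below. A minimal standalone sketch of that cosine schedule with an illustrative total_alpha_steps; note that the method hard-codes initial_alpha = 1 and final_alpha = 0.5 rather than reading the values stored from the config:

import math

def cosine_alpha(step, total_alpha_steps, initial_alpha=1.0, final_alpha=0.5):
    # Same shape as _update_alpha: anneal the self-attention weight from
    # initial_alpha to final_alpha over total_alpha_steps along a cosine curve.
    if step >= total_alpha_steps:
        return final_alpha
    cosine_decay = 0.5 * (1 + math.cos(math.pi * step / total_alpha_steps))
    return final_alpha + (initial_alpha - final_alpha) * cosine_decay

# With total_alpha_steps=1000: alpha(0)=1.0, alpha(500)=0.75, alpha(1000)=0.5,
# so the layer starts as pure self-attention and settles at an even blend.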
    def _update_alpha(self, current_step):
        """
        Computes the cross-attention mixing weight (alpha) for the current step using a
        cosine decay schedule.

        Args:
            current_step (int): The current training step.

        Returns:
            float: The alpha value for the current step.
        """
        initial_alpha = 1
        final_alpha = .5
        if current_step >= self.total_alpha_steps:
            return final_alpha

        # Compute the cosine decay factor
        cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))

        # Compute the current alpha
        alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay

        return alpha


GPT_NEOX_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.

    Parameters:
        config ([`~PythiaSeekConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

GPT_NEOX_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
            model's internal embedding lookup matrix.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
"""
class PythiaSeekModel(PythiaSeekPreTrainedModel):

@@ -427,6 +564,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
@@ -436,6 +574,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        step: Optional[int] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -528,26 +667,29 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, step)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    hidden_states,
                    encoder_hidden_states,
                    attention_mask,
                    position_ids,
                    head_mask[i],
                )
            else:
                outputs = layer(
-                    hidden_states,
+                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    head_mask=head_mask[i],
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    step=step,
                )
            hidden_states = outputs[0]
            if use_cache is True:
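The create_custom_forward change above threads the new step argument through gradient checkpointing: extra non-tensor arguments are easier to capture in a closure than to route through torch.utils.checkpoint.checkpoint alongside the tensor inputs. A self-contained sketch of that closure pattern; the Linear layer and sizes are placeholders, not part of the model:

import torch
from torch.utils.checkpoint import checkpoint

def make_custom_forward(module, *extras):
    # Constants such as use_cache or step are captured here; only the tensor
    # inputs flow through checkpoint() and are recomputed in the backward pass.
    def custom_forward(*inputs):
        return module(*inputs, *extras)
    return custom_forward

layer = torch.nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)
out = checkpoint(make_custom_forward(layer), x)
out.sum().backward()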
@@ -577,9 +719,16 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

-        self.gpt_neox = PythiaSeekModel(config)
+        self.pythiaseek = PythiaSeekModel(config)
        self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.hidden_size = config.hidden_size
        self.encoder_dim = config.encoder_dim

        if self.hidden_size != self.encoder_dim:
            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
                                              nn.Linear(config.hidden_size * 4, config.hidden_size))

        # Initialize weights and apply final processing
        self.post_init()

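When the encoder width differs from the decoder hidden size, the enc_dec_proj defined above maps the neighbor embeddings up to the model dimension before they reach cross-attention. A quick standalone sketch with illustrative sizes (384 matches the MiniLM encoder used in the test script below):

import torch
import torch.nn as nn

encoder_dim, hidden_size = 384, 1024
enc_dec_proj = nn.Sequential(nn.Linear(encoder_dim, hidden_size * 4),
                             nn.Linear(hidden_size * 4, hidden_size))

neighbors = torch.randn(1, 2, encoder_dim)   # (bs, knn, encoder_dim)
projected = enc_dec_proj(neighbors)          # (bs, knn, hidden_size)
assert projected.shape == (1, 2, hidden_size)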
@@ -591,17 +740,19 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
    def forward(
        self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        step: Optional[int] = None
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -644,8 +795,13 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        outputs = self.gpt_neox(
+        if self.hidden_size != self.encoder_dim:
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
+            encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
+
+        outputs = self.pythiaseek(
            input_ids,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
@@ -655,6 +811,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            step=step,
        )

        hidden_states = outputs[0]
gpt4all/models/pythiaseek/test_pythiaseek.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import torch
from gpt4all.models import PythiaSeekConfig, PythiaSeekForCausalLM
from transformers import AutoTokenizer, AutoModel

# seed torch
torch.manual_seed(0)

config = PythiaSeekConfig(encoder_dim=384,
                          hidden_size=1024,
                          intermediate_size=4096,
                          num_hidden_layers=4,
                          num_attention_heads=4)
print("loaded config")

print("loading model")
model = PythiaSeekForCausalLM(config)
print("loaded model")


tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = 1024

encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')


def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


text = "The quick brown fox jumps over the lazy dog."
print("Encoded knn")
tokenized = encoder_tokenizer(text, return_tensors="pt")

# bs, seq_len, dim
encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])

# make 2 neighbors
# (bs, knn, encoding_dim)
encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)

inputs = "What did the fox do?"

print("Encoded inputs")
tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")

print("Running model")
outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)

print(outputs)
print(outputs[0].shape)