mirror of https://github.com/nomic-ai/gpt4all.git
feat: forward for pythiaseek
commit 71db2b664a (parent d1b64d7eed)
@@ -130,3 +130,8 @@ class PythiaSeekConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.tie_word_embeddings = tie_word_embeddings
         self.use_parallel_residual = use_parallel_residual
+
+        self.encoder_dim = encoder_dim
+        self.total_alpha_steps = total_alpha_steps
+        self.initial_alpha = initial_alpha
+        self.final_alpha = final_alpha
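The four new attributes persist the retrieval-encoder width and the alpha-schedule bounds on the config. A minimal sketch of constructing such a config, assuming the same import path as the test script at the bottom of this commit; the total_alpha_steps value here is illustrative, not taken from the commit:

from gpt4all.models import PythiaSeekConfig

# Illustrative values; the test script below uses encoder_dim=384 with hidden_size=1024.
config = PythiaSeekConfig(
    encoder_dim=384,           # width of the retrieval-encoder embeddings
    total_alpha_steps=10_000,  # assumed number of steps over which alpha decays
    initial_alpha=1.0,
    final_alpha=0.5,
)

Note that in this commit _update_alpha hard-codes initial_alpha = 1 and final_alpha = .5, so the stored config values are carried on the config but not yet consulted by the schedule.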
@@ -233,6 +233,131 @@ class PythiaSeekAttention(nn.Module):
         attn_output = torch.matmul(attn_weights, value)
         return attn_output, attn_weights
 
 
+class PythiaSeekCrossAttention(PythiaSeekAttention):
+    def __init__(self, config):
+        super().__init__(config)
+
+        max_positions = config.max_position_embeddings
+        self.register_buffer(
+            "bias",
+            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view(
+                1, 1, max_positions, max_positions
+            ),
+        )
+        self.register_buffer("masked_bias", torch.tensor(-1e9))
+
+        self.embed_dim = config.hidden_size
+        self.num_attention_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_attention_heads
+        if self.head_dim * self.num_attention_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
+                f" `num_attention_heads`: {self.num_attention_heads})."
+            )
+        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+
+    def _split_knn_attn_heads(self, tensor, num_attention_heads, attn_head_size):
+        new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size)
+        tensor = tensor.view(new_shape)
+
+        return tensor.permute(0, 1, 3, 2)
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        """
+        Merges attn_head_size dim and num_attn_heads dim into hidden dim
+        """
+        # tensor -> (bs, seq_len, num_attention_heads, head_dim)
+        tensor = tensor.permute(0, 2, 1, 3).contiguous()
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def _attn(
+        self,
+        query,
+        key,
+        value,
+        attention_mask=None,
+        head_mask=None,
+    ):
+        # query -> (bs, num_attention_heads, seq_len, head_dim)
+        # key -> (bs, num_attention_heads, head_dim, neighbors)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        attn_weights = torch.matmul(query, key)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+        attn_weights = attn_weights.to(value.dtype)
+
+        # Mask heads if we want to
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        # value -> (bs, num_attention_heads, seq_len, head_dim)
+        # attn_weights -> (bs, num_attention_heads, seq_len, neighbors)
+        # attn_output -> (bs, num_attention_heads, seq_len, head_dim)
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def forward(
+        self,
+        hidden_states: Optional[torch.FloatTensor],
+        encoder_hidden_states: Optional[torch.FloatTensor],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        layer_past: Optional[Tuple[torch.Tensor]] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> Union[
+        Tuple[torch.Tensor, Tuple[torch.Tensor]],
+        Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+    ]:
+        query = self.q_proj(hidden_states)
+        # cross attention: keys and values come from the retrieved encoder states
+        key = self.k_proj(encoder_hidden_states)
+        value = self.v_proj(encoder_hidden_states)
+        # (bs, seq_len, dim) -> (bs, num_attention_heads, seq_len, head_dim)
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim)
+        # (bs, knn, dim) -> per-head layout over the retrieved neighbors
+        key = self._split_knn_attn_heads(key, self.num_attention_heads, self.head_dim)
+        value = self._split_knn_attn_heads(value, self.num_attention_heads, self.head_dim)
+
+        value = value.permute(0, 3, 1, 2)
+        key = key.permute(0, 3, 2, 1)
+
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            key = torch.cat((past_key, key), dim=-2)
+            value = torch.cat((past_value, value), dim=-2)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        # compute cross-attention: Softmax(QK^T) x V over the retrieved neighbors
+        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+
+        outputs = (attn_output, present)
+        if output_attentions:
+            outputs += (attn_weights,)
+
+        return outputs  # attn_output, present, (attentions)
+
+
 def attention_mask_func(attention_scores, ltor_mask):
     attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min)
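The _attn above attends from the token sequence to the retrieved-neighbor keys and values rather than to other tokens. As a standalone sanity check of the shape flow (toy sizes, plain torch, independent of the repo):

import torch

bs, seq_len, knn, heads, head_dim = 2, 5, 3, 4, 8

query = torch.randn(bs, heads, seq_len, head_dim)   # per-head queries over tokens
key = torch.randn(bs, heads, head_dim, knn)         # per-head keys over retrieved neighbors
value = torch.randn(bs, heads, knn, head_dim)       # per-head values over retrieved neighbors

attn_weights = torch.softmax(torch.matmul(query, key), dim=-1)  # (bs, heads, seq_len, knn)
attn_output = torch.matmul(attn_weights, value)                 # (bs, heads, seq_len, head_dim)
print(attn_weights.shape, attn_output.shape)

Each token distributes its attention over the knn neighbors, so the output stays (bs, heads, seq_len, head_dim) and merges back to (bs, seq_len, hidden_size) exactly as _merge_heads does.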
@@ -308,18 +433,29 @@ class PythiaSeekLayer(nn.Module):
         self.attention = PythiaSeekAttention(config)
         self.mlp = PythiaSeekMLP(config)
 
+        self.cross_attn_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.cross_attn = PythiaSeekCrossAttention(config)
+        self.cross_attn_mlp = PythiaSeekMLP(config)
+
+        self.total_alpha_steps = config.total_alpha_steps
+        self.initial_alpha = config.initial_alpha
+        self.final_alpha = config.final_alpha
+
     def forward(
         self,
         hidden_states: Optional[torch.FloatTensor],
+        encoder_hidden_states: Optional[torch.FloatTensor],
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = False,
         layer_past: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
+        step: Optional[int] = None,
     ):
+        ln_hidden_states = self.input_layernorm(hidden_states)
         attention_layer_outputs = self.attention(
-            self.input_layernorm(hidden_states),
+            ln_hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             layer_past=layer_past,
@@ -334,74 +470,75 @@ class PythiaSeekLayer(nn.Module):
             # pseudocode:
             # x = x + attn(ln1(x)) + mlp(ln2(x))
             mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
-            hidden_states = mlp_output + attn_output + hidden_states
+            self_attention_residual = mlp_output + attn_output + hidden_states
         else:
             # pseudocode:
             # x = x + attn(ln1(x))
             # x = x + mlp(ln2(x))
             attn_output = attn_output + hidden_states
             mlp_output = self.mlp(self.post_attention_layernorm(attn_output))
-            hidden_states = mlp_output + attn_output
+            self_attention_residual = mlp_output + attn_output
+
+        # encoder_hidden_states -> (bs, knn, encoder_dim)
+        if encoder_hidden_states.dtype != ln_hidden_states.dtype:
+            encoder_hidden_states = encoder_hidden_states.to(ln_hidden_states.dtype)
+
+        encoder_normed = self.cross_attn_ln(encoder_hidden_states)
+
+        # cross_attn_output -> (bs, seq_len, dim)
+        cross_attn_output = self.cross_attn(
+            ln_hidden_states,
+            encoder_hidden_states=encoder_normed,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+        )
+
+        cross_attn_ff = self.cross_attn_mlp(
+            cross_attn_output[0]
+        )
+
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
+
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+
         if use_cache:
-            outputs = (hidden_states,) + outputs  # hidden_states, present, (attn_weights)
+            outputs = (hidden_states,) + outputs
         else:
-            outputs = (hidden_states,) + outputs[1:]  # hidden_states, (attn_weights)
+            outputs = (hidden_states,) + outputs[1:]
 
-        return outputs
+        return outputs  # hidden_states, present, (attentions)
+
+    def _update_alpha(self, current_step):
+        """
+        Computes the cross-attention blending weight (alpha) for the current step using a
+        cosine decay schedule.
+
+        Args:
+            current_step (int): The current training step.
+
+        Returns:
+            float: The alpha value for the current step.
+        """
+        initial_alpha = 1
+        final_alpha = .5
+        if current_step >= self.total_alpha_steps:
+            return final_alpha
+
+        # Compute the cosine decay factor
+        cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))
+
+        # Compute the current alpha
+        alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay
+
+        return alpha
 
-
-GPT_NEOX_START_DOCSTRING = r"""
-    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
-    it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
-    behavior.
-
-    Parameters:
-        config ([`~PythiaSeekConfig`]): Model configuration class with all the parameters of the model.
-            Initializing with a config file does not load the weights associated with the model, only the
-            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-GPT_NEOX_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `({0})`):
-            Indices of input sequence tokens in the vocabulary.
-
-            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-            [`PreTrainedTokenizer.__call__`] for details.
-
-            [What are input IDs?](../glossary#input-ids)
-        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
-            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-            - 1 for tokens that are **not masked**,
-            - 0 for tokens that are **masked**.
-
-            [What are attention masks?](../glossary#attention-mask)
-        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-            config.n_positions - 1]`.
-
-            [What are position IDs?](../glossary#position-ids)
-        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-            - 1 indicates the head is **not masked**,
-            - 0 indicates the head is **masked**.
-
-        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-            is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
-            model's internal embedding lookup matrix.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-            tensors for more detail.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-            more detail.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
-"""
 
 
 class PythiaSeekModel(PythiaSeekPreTrainedModel):
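The layer output is a convex blend, hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual, with alpha following a half-cosine decay from 1 to 0.5 over total_alpha_steps (and defaulting to 0.5 when no step is passed). A standalone sketch of the schedule, assuming a total of 1000 steps:

import math

def cosine_alpha(current_step, total_alpha_steps, initial_alpha=1.0, final_alpha=0.5):
    # Mirrors PythiaSeekLayer._update_alpha: start at initial_alpha, decay along a
    # half cosine, and clamp to final_alpha once total_alpha_steps is reached.
    if current_step >= total_alpha_steps:
        return final_alpha
    cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / total_alpha_steps))
    return final_alpha + (initial_alpha - final_alpha) * cosine_decay

print([round(cosine_alpha(s, 1000), 3) for s in (0, 250, 500, 750, 1000)])
# [1.0, 0.927, 0.75, 0.573, 0.5]

Early in training the layer therefore leans almost entirely on the self-attention residual, and the retrieved-neighbor path is phased in until the two are weighted evenly.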
@@ -427,6 +564,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
@@ -436,6 +574,7 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None
     ) -> Union[Tuple, BaseModelOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
@@ -528,26 +667,29 @@ class PythiaSeekModel(PythiaSeekPreTrainedModel):
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
                         # None for layer_past
-                        return module(*inputs, use_cache, None, output_attentions)
+                        return module(*inputs, use_cache, None, output_attentions, step)
 
                     return custom_forward
 
                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(layer),
                     hidden_states,
+                    encoder_hidden_states,
                     attention_mask,
                     position_ids,
                     head_mask[i],
                 )
             else:
                 outputs = layer(
-                    hidden_states,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     position_ids=position_ids,
                     head_mask=head_mask[i],
                     layer_past=layer_past,
                     use_cache=use_cache,
                     output_attentions=output_attentions,
+                    step=step,
                 )
             hidden_states = outputs[0]
             if use_cache is True:
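In the gradient-checkpointed branch, the layer's non-tensor flags (use_cache, output_attentions, and now step) are baked into the custom_forward closure, while the tensors travel through *inputs so the checkpoint machinery can save and replay them. A minimal, generic illustration of that pattern, not repo code and with a toy module:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)

def create_custom_forward(module):
    # Constant flags would be captured here; only tensor inputs are passed
    # through and recomputed during the backward pass.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(2, 8, requires_grad=True)
y = checkpoint(create_custom_forward(layer), x)
y.sum().backward()
print(x.grad.shape)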
@@ -577,9 +719,16 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.gpt_neox = PythiaSeekModel(config)
+        self.pythiaseek = PythiaSeekModel(config)
         self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        self.hidden_size = config.hidden_size
+        self.encoder_dim = config.encoder_dim
+
+        if self.hidden_size != self.encoder_dim:
+            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.hidden_size * 4),
+                                              nn.Linear(config.hidden_size * 4, config.hidden_size))
+
         # Initialize weights and apply final processing
         self.post_init()
 
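When the retrieval encoder's width differs from the decoder's hidden size, the new enc_dec_proj maps neighbor embeddings up to the decoder width before they reach the cross-attention layers. A toy sketch of that projection with the sizes used in the test script below (384-dim MiniLM sentence embeddings, 1024-dim decoder):

import torch
import torch.nn as nn

encoder_dim, hidden_size = 384, 1024
# Same two-layer shape as the enc_dec_proj built in __init__ above.
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, hidden_size * 4),
    nn.Linear(hidden_size * 4, hidden_size),
)

neighbors = torch.randn(1, 2, encoder_dim)   # (bs, knn, encoder_dim)
print(enc_dec_proj(neighbors).shape)         # torch.Size([1, 2, 1024])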
@@ -591,17 +740,19 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        step: Optional[int] = None
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
@@ -644,8 +795,13 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        outputs = self.gpt_neox(
+        if self.hidden_size != self.encoder_dim:
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
+            encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
+
+        outputs = self.pythiaseek(
             input_ids,
+            encoder_hidden_states=encoder_hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             head_mask=head_mask,
@@ -655,6 +811,7 @@ class PythiaSeekForCausalLM(PythiaSeekPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            step=step,
         )
 
         hidden_states = outputs[0]
gpt4all/models/pythiaseek/test_pythiaseek.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+import torch
+from gpt4all.models import PythiaSeekConfig, PythiaSeekForCausalLM
+from transformers import AutoTokenizer, AutoModel
+
+# seed torch
+
+torch.manual_seed(0)
+
+config = PythiaSeekConfig(encoder_dim=384,
+                          hidden_size=1024,
+                          intermediate_size=4096,
+                          num_hidden_layers=4,
+                          num_attention_heads=4)
+print("loaded config")
+
+print("loading model")
+model = PythiaSeekForCausalLM(config)
+print("loaded model")
+
+
+tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1b")
+tokenizer.pad_token = tokenizer.eos_token
+tokenizer.model_max_length = 1024
+
+encoder_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+encoder = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
+
+
+def mean_pooling(model_output, attention_mask):
+    token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+
+text = "The quick brown fox jumps over the lazy dog."
+print("Encoded knn")
+tokenized = encoder_tokenizer(text, return_tensors="pt")
+
+# bs, seq_len, dim
+encodings = mean_pooling(encoder(**tokenized), tokenized["attention_mask"])
+
+# make 2 neighbors
+# (bs, knn, encoding_dim)
+encoder_hidden_states = torch.stack([encodings, encodings]).squeeze().unsqueeze(0)
+
+inputs = "What did the fox do?"
+
+print("Encoded inputs")
+tokenized_input = tokenizer([inputs], padding="max_length", truncation=True, return_tensors="pt")
+
+print("Running model")
+outputs = model(**tokenized_input, encoder_hidden_states=encoder_hidden_states)
+
+print(outputs)
+print(outputs[0].shape)
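A hedged sanity check one might append to the script: with the prompt padded to tokenizer.model_max_length, the logits in outputs[0] should come out as (batch, 1024, vocab_size), where vocab_size is whatever PythiaSeekConfig defaults to, since the test never sets it.

logits = outputs[0]
assert logits.shape[0] == 1                  # single prompt
assert logits.shape[1] == 1024               # padded to tokenizer.model_max_length
assert logits.shape[2] == config.vocab_size  # decoder vocabulary size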