mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-07 04:20:59 +00:00)

feat: cosine alpha schedule

parent 27a9b2b10c
commit bd6e471555
@@ -476,14 +476,13 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
 
-        # if step is not None:
-        #     alpha = self._update_alpha(step)
-        #     #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
-        # else:
-        #     alpha = 0.5
-
-        # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
-        hidden_states = cross_attn_ff + self_attention_residual
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
+
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
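The hunk above swaps the plain residual sum for an alpha-weighted blend of the cross-attention feed-forward output and the self-attention residual, with alpha supplied by the new schedule (and pinned to 0.5 when no step is passed). A minimal sketch of that gating, using made-up tensor shapes rather than anything from the repo:

```python
import torch

def blend(cross_attn_ff, self_attention_residual, alpha):
    # alpha = 1.0 keeps only the self-attention residual (useful while the
    # cross-attention path is still untrained); alpha = 0.5 mixes both evenly.
    return (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

# Illustrative (batch, seq_len, hidden) activations.
cross_attn_ff = torch.randn(2, 8, 16)
self_attention_residual = torch.randn(2, 8, 16)

early = blend(cross_attn_ff, self_attention_residual, alpha=1.0)  # start of training
late = blend(cross_attn_ff, self_attention_residual, alpha=0.5)   # after the schedule bottoms out
```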
@@ -491,10 +490,6 @@ class GPTJRBlock(nn.Module):
 
         return outputs  # hidden_states, present, (attentions)
 
-    # def _update_alpha(self, iteration):
-    #     return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
-
-
     def _update_alpha(self, current_step):
         """
         Computes the learning rate for the current step using a cosine decay schedule.
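Only the signature and docstring of `_update_alpha` appear in this hunk, so the exact schedule is not visible here; the removed comment suggests alpha previously followed a power-law decay clamped to [0.5, 1.0]. As an assumption, a cosine decay over that same range could look like the sketch below, where `total_steps`, `alpha_max`, and `alpha_min` are hypothetical parameters rather than names from the repo:

```python
import math
import torch

def cosine_alpha(current_step, total_steps, alpha_max=1.0, alpha_min=0.5):
    # Cosine decay from alpha_max at step 0 down to alpha_min at total_steps.
    progress = min(max(current_step / total_steps, 0.0), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.tensor(alpha_min + (alpha_max - alpha_min) * cosine)

print(cosine_alpha(0, 1000))     # tensor(1.)
print(cosine_alpha(500, 1000))   # tensor(0.7500)
print(cosine_alpha(1000, 1000))  # tensor(0.5000)
```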
@@ -801,7 +796,8 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         self.encoder_dim = config.encoder_dim
 
         if self.hidden_size != self.encoder_dim:
-            self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
+            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.n_embd * 4),
+                                              nn.Linear(config.n_embd * 4, config.n_embd))
 
         # Model parallel
         self.model_parallel = False
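Here the encoder-to-decoder projection grows from a single linear layer to two stacked linears with a 4x-wide intermediate (no nonlinearity between them in this diff). A small standalone sketch with illustrative widths standing in for the config values:

```python
import torch
import torch.nn as nn

# Illustrative widths; the real values come from config.encoder_dim and config.n_embd.
encoder_dim, n_embd = 384, 768

enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, n_embd * 4),  # project up to a 4x-wide intermediate
    nn.Linear(n_embd * 4, n_embd),       # project back down to the decoder width
)

encoder_hidden_states = torch.randn(2, 8, encoder_dim)
print(enc_dec_proj(encoder_hidden_states).shape)  # torch.Size([2, 8, 768])
```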
@@ -889,7 +885,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if self.hidden_size != self.encoder_dim:
-            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
 
         transformer_outputs = self.transformer(
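The dtype lookup has to change with the projection: an `nn.Sequential` has no `.weight` of its own, so the cast now reads the dtype from the first child `nn.Linear`. A sketch of why the indexing is needed, reusing the illustrative widths from above:

```python
import torch
import torch.nn as nn

# Hypothetical half-precision projection matching the shape of the new enc_dec_proj.
enc_dec_proj = nn.Sequential(
    nn.Linear(384, 768 * 4),
    nn.Linear(768 * 4, 768),
).half()

encoder_hidden_states = torch.randn(2, 8, 384)
# enc_dec_proj.weight would raise AttributeError; enc_dec_proj[0].weight works.
encoder_hidden_states = encoder_hidden_states.to(enc_dec_proj[0].weight.dtype)
print(encoder_hidden_states.dtype)                # torch.float16
print(enc_dec_proj(encoder_hidden_states).shape)  # torch.Size([2, 8, 768])
```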