Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-07 20:41:24 +00:00)
feat: cosine alpha schedule

commit bd6e471555
parent 27a9b2b10c
@@ -476,14 +476,13 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
 
-        # if step is not None:
-        #     alpha = self._update_alpha(step)
-        #     #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
-        # else:
-        #     alpha = 0.5
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
 
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
 
-        # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
-        hidden_states = cross_attn_ff + self_attention_residual
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
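For reference, the hunk above replaces the plain sum of the two branches with a step-dependent blend: the cross-attention feed-forward output and the self-attention residual are mixed by a weight alpha. A minimal standalone sketch of that interpolation, with placeholder tensors standing in for the real activations:

import torch

# Placeholder activations standing in for cross_attn_ff and self_attention_residual.
cross_attn_ff = torch.randn(1, 8, 16)
self_attention_residual = torch.randn(1, 8, 16)

alpha = 0.5  # fallback weight used when no training step is supplied
hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
# alpha == 1.0 keeps only the self-attention path; alpha == 0.0 keeps only the
# cross-attention path, so decaying alpha gradually phases the retrieval branch in.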
@@ -491,10 +490,6 @@ class GPTJRBlock(nn.Module):
         return outputs  # hidden_states, present, (attentions)
 
-    # def _update_alpha(self, iteration):
-    #     return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
-
-
     def _update_alpha(self, current_step):
         """
         Computes the learning rate for the current step using a cosine decay schedule.
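The body of the new _update_alpha is cut off at the hunk boundary, so the following is only a sketch of what a cosine decay of alpha from 1.0 down to a 0.5 floor could look like; total_steps, min_alpha, and max_alpha are assumed names and values, not taken from the commit:

import math

def update_alpha(current_step, total_steps=10_000, min_alpha=0.5, max_alpha=1.0):
    # Sketch only: cosine decay from max_alpha at step 0 to min_alpha at total_steps.
    progress = min(max(current_step, 0), total_steps) / total_steps
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_alpha + (max_alpha - min_alpha) * cosine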
@@ -801,7 +796,8 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         self.encoder_dim = config.encoder_dim
 
         if self.hidden_size != self.encoder_dim:
-            self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
+            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.n_embd * 4),
+                                              nn.Linear(config.n_embd * 4, config.n_embd))
 
         # Model parallel
         self.model_parallel = False
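For illustration, the single nn.Linear projection is replaced by a two-layer nn.Sequential that first expands the encoder states to 4 * n_embd and then projects them down to the decoder width. A self-contained sketch with made-up sizes (not the model's actual config values):

import torch
import torch.nn as nn

encoder_dim, n_embd = 384, 1024  # illustrative sizes only
enc_dec_proj = nn.Sequential(
    nn.Linear(encoder_dim, n_embd * 4),  # expand retrieved encoder states
    nn.Linear(n_embd * 4, n_embd),       # project down to the decoder hidden size
)

encoder_hidden_states = torch.randn(2, 32, encoder_dim)
projected = enc_dec_proj(encoder_hidden_states)
print(projected.shape)  # torch.Size([2, 32, 1024])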
@@ -889,7 +885,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if self.hidden_size != self.encoder_dim:
-            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
 
         transformer_outputs = self.transformer(
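The dtype cast changes accordingly: an nn.Sequential has no .weight attribute of its own, so the target dtype is read from its first linear layer. Continuing the sketch above:

# Index into the first nn.Linear of the Sequential to get the parameter dtype.
target_dtype = enc_dec_proj[0].weight.dtype
encoder_hidden_states = encoder_hidden_states.to(target_dtype)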