feat: cosine alpha schedule

Zach Nussbaum 2023-05-03 21:39:14 +00:00
parent 27a9b2b10c
commit bd6e471555


@@ -476,14 +476,13 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
-        # if step is not None:
-        #     alpha = self._update_alpha(step)
-        #     #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
-        # else:
-        #     alpha = 0.5
-        # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
-        hidden_states = cross_attn_ff + self_attention_residual
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
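
Note on the hunk above: alpha now blends the cross-attention feed-forward output with the self-attention residual instead of summing them directly; alpha = 0.5 mixes the two branches evenly, and larger alpha weights the self-attention path more heavily. A minimal standalone sketch of the blend (tensor names and shapes here are illustrative, not taken from the model):

import torch

def blend(cross_attn_ff, self_attention_residual, alpha):
    # Convex combination of the two branches; alpha = 0.5 is an even mix.
    return (1 - alpha) * cross_attn_ff + alpha * self_attention_residual

cross_attn_ff = torch.randn(2, 8, 16)            # stand-in cross-attention feed-forward output
self_attention_residual = torch.randn(2, 8, 16)  # stand-in self-attention residual
hidden_states = blend(cross_attn_ff, self_attention_residual, alpha=0.5)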
@@ -491,10 +490,6 @@ class GPTJRBlock(nn.Module):
         return outputs  # hidden_states, present, (attentions)

-    # def _update_alpha(self, iteration):
-    #     return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))

     def _update_alpha(self, current_step):
         """
         Computes the learning rate for the current step using a cosine decay schedule.
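
The new _update_alpha is only partially visible in this hunk: its docstring says it follows a cosine decay schedule (matching the commit title), and it replaces the removed power-law version that clamped 1 / (iteration * world_size) ** 0.08 to [0.5, 1.0]. Below is a minimal sketch of a generic cosine decay clamped to that same [0.5, 1.0] range, with hypothetical parameter names (total_steps, alpha_min, alpha_max) rather than the commit's actual signature:

import math

def cosine_alpha(current_step, total_steps, alpha_min=0.5, alpha_max=1.0):
    # Decays from alpha_max at step 0 to alpha_min at total_steps along a half cosine.
    progress = min(current_step / max(total_steps, 1), 1.0)
    return alpha_min + 0.5 * (alpha_max - alpha_min) * (1 + math.cos(math.pi * progress))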
@@ -801,7 +796,8 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         self.encoder_dim = config.encoder_dim

         if self.hidden_size != self.encoder_dim:
-            self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
+            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.n_embd * 4),
+                                              nn.Linear(config.n_embd * 4, config.n_embd))

         # Model parallel
         self.model_parallel = False
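
The projection from encoder width to decoder width is now a two-layer nn.Sequential with a 4x wider intermediate dimension rather than a single nn.Linear. A quick shape check under assumed sizes (encoder_dim = 384 and n_embd = 256 are illustrative values, not taken from the config):

import torch
import torch.nn as nn

encoder_dim, n_embd = 384, 256  # illustrative sizes
enc_dec_proj = nn.Sequential(nn.Linear(encoder_dim, n_embd * 4),
                             nn.Linear(n_embd * 4, n_embd))
encoder_hidden_states = torch.randn(2, 10, encoder_dim)
print(enc_dec_proj(encoder_hidden_states).shape)  # torch.Size([2, 10, 256])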
@@ -889,7 +885,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if self.hidden_size != self.encoder_dim:
-            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)

         transformer_outputs = self.transformer(