diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py
index 894a8e81..e1c14cf9 100644
--- a/gpt4all/models/modeling_gpt_jr.py
+++ b/gpt4all/models/modeling_gpt_jr.py
@@ -476,14 +476,13 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
 
-        # if step is not None:
-        #     alpha = self._update_alpha(step)
-        #     #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
-        # else:
-        #     alpha = 0.5
+        if step is not None:
+            alpha = self._update_alpha(step)
+        else:
+            alpha = 0.5
+
+        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
 
-        # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
-        hidden_states = cross_attn_ff + self_attention_residual
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
@@ -491,10 +490,6 @@ class GPTJRBlock(nn.Module):
 
         return outputs  # hidden_states, present, (attentions)
 
-    # def _update_alpha(self, iteration):
-    #     return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
-
-
     def _update_alpha(self, current_step):
         """
         Computes the learning rate for the current step using a cosine decay schedule.
@@ -801,7 +796,8 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         self.encoder_dim = config.encoder_dim
 
         if self.hidden_size != self.encoder_dim:
-            self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
+            self.enc_dec_proj = nn.Sequential(nn.Linear(config.encoder_dim, config.n_embd * 4),
+                                              nn.Linear(config.n_embd * 4, config.n_embd))
 
         # Model parallel
         self.model_parallel = False
@@ -889,7 +885,7 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if self.hidden_size != self.encoder_dim:
-            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj[0].weight.dtype)
             encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
 
         transformer_outputs = self.transformer(
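
Note on the re-enabled alpha blend: the block output becomes a convex combination of the cross-attention feed-forward path and the self-attention residual, with the mixing weight scheduled by step. The body of the new _update_alpha is not shown in this diff; its docstring mentions a cosine decay schedule, so the following is a minimal standalone sketch of what such a schedule could look like, assuming a decay from 1.0 down to a 0.5 floor (the bounds the deleted clamp-based version used) over a hypothetical total_steps horizon.

    import math

    def update_alpha(current_step, total_steps, alpha_min=0.5, alpha_max=1.0):
        # Hypothetical cosine decay from alpha_max to alpha_min over total_steps;
        # the actual _update_alpha body is not part of this diff.
        progress = min(current_step / max(total_steps, 1), 1.0)
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        return alpha_min + (alpha_max - alpha_min) * cosine

    # Early in training alpha is near 1.0, so the block output is dominated by
    # the self-attention residual; as training proceeds, the cross-attention
    # feed-forward path is weighted more heavily:
    #   hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual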