diff --git a/gpt4all/models/configuration_gpt_jr.py b/gpt4all/models/configuration_gpt_jr.py
index 314a87af..b5bcebb3 100644
--- a/gpt4all/models/configuration_gpt_jr.py
+++ b/gpt4all/models/configuration_gpt_jr.py
@@ -116,7 +116,9 @@ class GPTJRConfig(PretrainedConfig):
         eos_token_id=50256,
         tie_word_embeddings=False,
         encoder_dim=4096,
-        encoder_path=None,
+        total_alpha_steps=0,
+        initial_alpha=1,
+        final_alpha=.5,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -140,6 +142,10 @@ class GPTJRConfig(PretrainedConfig):
 
         self.encoder_dim = encoder_dim
 
+        self.total_alpha_steps = total_alpha_steps
+        self.initial_alpha = initial_alpha
+        self.final_alpha = final_alpha
+
         super().__init__(
             bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
         )
diff --git a/gpt4all/models/modeling_gpt_jr.py b/gpt4all/models/modeling_gpt_jr.py
index 36e053ea..894a8e81 100644
--- a/gpt4all/models/modeling_gpt_jr.py
+++ b/gpt4all/models/modeling_gpt_jr.py
@@ -16,6 +16,7 @@
 
 from typing import Optional, Tuple, Union
 
+import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -286,8 +287,8 @@ class GPTJRCrossAttention(GPTJRAttention):
         )
         self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
 
-        self.k_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
-        self.v_proj = nn.Linear(config.encoder_dim, self.embed_dim, bias=False)
+        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
+        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
 
@@ -419,11 +420,13 @@ class GPTJRBlock(nn.Module):
         self.attn = GPTJRAttention(config)
         self.mlp = GPTJRMLP(inner_dim, config)
 
-        self.ln_2 = nn.LayerNorm(config.encoder_dim, eps=config.layer_norm_epsilon)
+        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.cross_attn = GPTJRCrossAttention(config)
         self.cross_attn_mlp = GPTJRMLP(inner_dim, config)
 
-        self.world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else torch.cuda.device_count() or 1
+        self.total_alpha_steps = config.total_alpha_steps
+        self.initial_alpha = config.initial_alpha
+        self.final_alpha = config.final_alpha
 
     def forward(
         self,
@@ -473,13 +476,14 @@ class GPTJRBlock(nn.Module):
             cross_attn_output[0]
         )
 
-        if step is not None:
-            alpha = self._update_alpha(step)
-            alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
-        else:
-            alpha = 0.5
+        # if step is not None:
+        #     alpha = self._update_alpha(step)
+        #     #alpha = alpha.to(cross_attn_ff.device).to(cross_attn_ff.dtype)
+        # else:
+        #     alpha = 0.5
 
-        hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        # hidden_states = (1 - alpha) * cross_attn_ff + alpha * self_attention_residual
+        hidden_states = cross_attn_ff + self_attention_residual
         if use_cache:
             outputs = (hidden_states,) + outputs
         else:
@@ -487,8 +491,35 @@ class GPTJRBlock(nn.Module):
 
         return outputs  # hidden_states, present, (attentions)
 
-    def _update_alpha(self, iteration):
-        return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
+    # def _update_alpha(self, iteration):
+    #     return torch.clamp(torch.tensor([1 / (max(iteration * self.world_size, 1)) ** 0.08]), min=torch.tensor([0.5]), max=torch.tensor([1.0]))
+
+
+    def _update_alpha(self, current_step):
+        """
+        Computes the mixing coefficient alpha for the current step using a
+        cosine decay schedule, annealing from self.initial_alpha to
+        self.final_alpha over self.total_alpha_steps steps.
+
+        Args:
+            current_step (int): The current training step. Steps at or beyond
+                self.total_alpha_steps return self.final_alpha.
+
+        Returns:
+            float: The alpha value for the current step.
+        """
+        initial_alpha = self.initial_alpha
+        final_alpha = self.final_alpha
+        if current_step >= self.total_alpha_steps:
+            return final_alpha
+
+        # Compute the cosine decay factor
+        cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / self.total_alpha_steps))
+
+        # Interpolate between final_alpha and initial_alpha
+        alpha = final_alpha + (initial_alpha - final_alpha) * cosine_decay
+
+        return alpha
 
 
 class GPTJRPreTrainedModel(PreTrainedModel):
@@ -766,6 +797,12 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         self.transformer = GPTJRModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size)
 
+        self.hidden_size = config.hidden_size
+        self.encoder_dim = config.encoder_dim
+
+        if self.hidden_size != self.encoder_dim:
+            self.enc_dec_proj = nn.Linear(config.encoder_dim, config.n_embd)
+
         # Model parallel
         self.model_parallel = False
         self.device_map = None
@@ -851,6 +888,10 @@ class GPTJRForCausalLM(GPTJRPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if self.hidden_size != self.encoder_dim:
+            encoder_hidden_states = encoder_hidden_states.to(self.enc_dec_proj.weight.dtype)
+            encoder_hidden_states = self.enc_dec_proj(encoder_hidden_states)
+
         transformer_outputs = self.transformer(
             input_ids,
             encoder_hidden_states=encoder_hidden_states,
diff --git a/gpt4all/train/train_retrieval.py b/gpt4all/train/train_retrieval.py
index 1e31f01d..2e4dfc2e 100644
--- a/gpt4all/train/train_retrieval.py
+++ b/gpt4all/train/train_retrieval.py
@@ -1,5 +1,5 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
 import torch
 from torch.optim import AdamW
 from argparse import ArgumentParser
@@ -57,6 +57,15 @@ def train(accelerator, config):
     with accelerator.main_process_first():
         train_dataloader, val_dataloader = load_retrieval_augmented_data(config, tokenizer)
 
+    if accelerator.state.deepspeed_plugin is not None:
+        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
+            "gradient_accumulation_steps"
+        ]
+
+    accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
+    total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
+    # total optimizer steps across all epochs; also passed to the model as total_alpha_steps
+    accelerator.print(f"Total training steps: {total_num_steps}")
 
     checkpoint = config["gradient_checkpointing"]
     #ensures back compat with non retrieval models
@@ -66,6 +75,7 @@ def train(accelerator, config):
                                              revision=config['version'] if 'version' in config else None,
                                              use_cache=False if checkpoint else True,
                                              encoder_dim=config["encoder_dim"],
+                                             total_alpha_steps=total_num_steps
                                              )
     else:
         model = AutoModelForCausalLM.from_pretrained(config["model_name"],
@@ -94,18 +104,6 @@ def train(accelerator, config):
     # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
     optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
 
-    if accelerator.state.deepspeed_plugin is not None:
-        gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
-            "gradient_accumulation_steps"
-        ]
-
-        # decay to min_lr instead of 0
-        lr_ratio = config["min_lr"] / config["lr"]
-        accelerator.print(f"Len of train_dataloader: {len(train_dataloader)}")
-        total_num_steps = (len(train_dataloader) / gradient_accumulation_steps) * config["num_epochs"]
-        # instead of decaying to zero, decay to ratio of min_lr / lr
-        total_num_steps += int(total_num_steps * lr_ratio) + config["warmup_steps"]
-        accelerator.print(f"Total training steps: {total_num_steps}")
 
     # Creates Dummy Scheduler if `scheduler` was spcified in the config file else creates `args.lr_scheduler_type` Scheduler
     if (
@@ -206,13 +204,14 @@ def train(accelerator, config):
             accelerator.print(f"Pushing to HF hub")
             accelerator.wait_for_everyone()
             unwrapped_model = accelerator.unwrap_model(model)
-            try:
-                if accelerator.is_main_process:
-                    unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
+            if config["push_to_hub"]:
+                try:
+                    if accelerator.is_main_process:
+                        unwrapped_model.push_to_hub(config["save_name"] + f"-epoch_{epoch}", private=True)
 
-            except Exception as e:
-                accelerator.print(e)
-                accelerator.print(f"Failed to push to hub")
+                except Exception as e:
+                    accelerator.print(e)
+                    accelerator.print(f"Failed to push to hub")
 
             unwrapped_model.save_pretrained(
                 f"{config['output_dir']}/epoch_{epoch}",
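
For reference, a minimal standalone sketch of the cosine alpha schedule that the new _update_alpha implements; the helper name and the example values below are illustrative, not part of the patch:

import math

def cosine_alpha(current_step, total_alpha_steps, initial_alpha=1.0, final_alpha=0.5):
    # Past the schedule horizon, alpha stays at its final value.
    if current_step >= total_alpha_steps:
        return final_alpha
    # Cosine factor decays from 1.0 at step 0 toward 0.0 at total_alpha_steps.
    cosine_decay = 0.5 * (1 + math.cos(math.pi * current_step / total_alpha_steps))
    # Interpolate so alpha anneals from initial_alpha toward final_alpha.
    return final_alpha + (initial_alpha - final_alpha) * cosine_decay

# With total_alpha_steps=1000, alpha goes 1.0 -> 0.75 -> 0.5.
for step in (0, 500, 1000):
    print(step, round(cosine_alpha(step, 1000), 3))

Note that in this revision the block's forward pass sums the cross-attention and self-attention branches directly (the alpha blend is commented out), so the schedule only takes effect if that blend is restored.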
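
The enc_dec_proj path in GPTJRForCausalLM only applies when the retrieval encoder's width differs from the decoder hidden size; below is a small sketch of that projection step with assumed example shapes (the sizes and tensor names are illustrative, not taken from the repo):

import torch
from torch import nn

hidden_size, encoder_dim = 2048, 4096  # assumed sizes for illustration
enc_dec_proj = nn.Linear(encoder_dim, hidden_size)

# (batch, num_retrieved_neighbors, encoder_dim) retrieval embeddings
encoder_hidden_states = torch.randn(2, 8, encoder_dim, dtype=torch.float16)

if hidden_size != encoder_dim:
    # Cast to the projection's dtype first, as the diff does, then project down.
    encoder_hidden_states = encoder_hidden_states.to(enc_dec_proj.weight.dtype)
    encoder_hidden_states = enc_dec_proj(encoder_hidden_states)

print(encoder_hidden_states.shape)  # torch.Size([2, 8, 2048])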