diff --git a/train.py b/train.py
index 463dd155..72e53f4d 100644
--- a/train.py
+++ b/train.py
@@ -1,8 +1,6 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
 import torch
-import torch.nn as nn
 from torch.optim import AdamW
 from argparse import ArgumentParser
 from read import read_config
@@ -45,7 +43,7 @@ def train(accelerator, config):
     accelerator.print(f"Using {accelerator.num_processes} GPUs")
 
     tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
-    # llama has no pad token, set it to new token
+    # if no pad token, set it to eos
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
@@ -76,21 +74,9 @@ def train(accelerator, config):
         else DummyOptim
     )
 
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": config["weight_decay"],
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
     # karpathy doesn't decay embeddding, maybe we should exclude
     # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
-    optimizer = optimizer_cls(optimizer_grouped_parameters, lr=config["lr"])
+    optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
 
     if accelerator.state.deepspeed_plugin is not None:
         gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
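For reviewers weighing the optimizer simplification, here is a minimal, self-contained sketch of the before/after behavior. The `TinyModel` and the config values are hypothetical, chosen only so the `no_decay` name matching has something to hit; the point is that the single-group form now applies `weight_decay` to biases and LayerNorm weights as well, which the removed grouped-parameter code deliberately excluded (the same concern the karpathy comment raises for embeddings).

```python
import torch.nn as nn
from torch.optim import AdamW


class TinyModel(nn.Module):
    """Hypothetical model whose parameter names mimic HF-style naming ("LayerNorm.weight", "*.bias")."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

    def forward(self, x):
        return self.LayerNorm(self.linear(x))


model = TinyModel()
config = {"lr": 2e-5, "weight_decay": 0.01}  # illustrative values only

# After this diff: a single parameter group with uniform weight decay,
# so biases and LayerNorm weights are decayed as well.
new_optimizer = AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

# Before this diff: two groups, with decay turned off for biases and LayerNorm weights.
no_decay = ["bias", "LayerNorm.weight"]
grouped = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": config["weight_decay"],
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
old_optimizer = AdamW(grouped, lr=config["lr"])
```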