Mirror of https://github.com/nomic-ai/gpt4all.git, synced 2025-06-27 07:48:19 +00:00
fix: eos/pad token + wd

commit 31195270cb
parent c82ee7d882

Changed file: train.py (20 lines changed)
--- a/train.py
+++ b/train.py
@@ -1,8 +1,6 @@
 import os
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, LlamaForCausalLM
 import torch
-import torch.nn as nn
 from torch.optim import AdamW
 from argparse import ArgumentParser
 from read import read_config
@@ -45,7 +43,7 @@ def train(accelerator, config):
     accelerator.print(f"Using {accelerator.num_processes} GPUs")
 
     tokenizer = AutoTokenizer.from_pretrained(config['tokenizer_name'], model_max_length=config['max_length'])
-    # llama has no pad token, set it to new token
+    # if no pad token, set it to eos
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
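The hunk above only brings the comment in line with the eos fallback the script already performs. To see that behaviour in isolation, here is a minimal sketch; "gpt2" and the max length are only illustrative stand-ins for whatever config['tokenizer_name'] and config['max_length'] point at:

from transformers import AutoTokenizer

# "gpt2" is a stand-in checkpoint; the training script loads config['tokenizer_name'] instead.
tokenizer = AutoTokenizer.from_pretrained("gpt2", model_max_length=512)

# Same fallback as in train.py: GPT-2 (like LLaMA) ships without a pad token,
# so padding-enabled batch tokenization would raise an error until one is assigned.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(tokenizer.pad_token)  # "<|endoftext|>" for GPT-2

Reusing eos as the pad token keeps the vocabulary unchanged (no embedding resize is needed), which is presumably why the earlier "set it to new token" approach was dropped.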
@@ -76,21 +74,9 @@ def train(accelerator, config):
         else DummyOptim
     )
 
-    no_decay = ["bias", "LayerNorm.weight"]
-    optimizer_grouped_parameters = [
-        {
-            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
-            "weight_decay": config["weight_decay"],
-        },
-        {
-            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
-            "weight_decay": 0.0,
-        },
-    ]
-
     # karpathy doesn't decay embeddding, maybe we should exclude
     # https://github.com/karpathy/minGPT/commit/bbbdac74fa9b2e55574d70056163ffbae42310c1#diff-2075fa9c224b395be5bda85544dd36572b59c76c54562819eadadbf268602834R157s
-    optimizer = optimizer_cls(optimizer_grouped_parameters, lr=config["lr"])
+    optimizer = optimizer_cls(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
 
     if accelerator.state.deepspeed_plugin is not None:
         gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[
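This hunk is the "wd" part of the commit message: the hand-rolled parameter groups that exempted biases and LayerNorm weights from decay are removed, and weight_decay is passed directly to the optimizer, so on the plain AdamW path every parameter is now decayed uniformly. A minimal sketch of the before/after construction, using a toy module and literal hyperparameters in place of the real model and config values:

import torch.nn as nn
from torch.optim import AdamW

class Toy(nn.Module):
    """Tiny stand-in for the real model, with an attribute literally named
    'LayerNorm' so the string filter below matches it (as it does for HF
    models whose layer norms are attributes named 'LayerNorm')."""
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)

    def forward(self, x):
        return self.LayerNorm(self.dense(x))

model = Toy()
lr, weight_decay = 2e-5, 0.01  # illustrative values; train.py reads them from its config

# Before the commit: two groups, so biases and LayerNorm weights receive no decay.
no_decay = ["bias", "LayerNorm.weight"]
grouped = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer_before = AdamW(grouped, lr=lr)

# After the commit: one group, decay applied uniformly to every parameter.
optimizer_after = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

The retained minGPT comment is the open question this leaves behind: embeddings (and, after this change, biases and norms as well) now receive weight decay.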