[Chat] Fix lora (#5946)

* fix merging

* remove filepath

* fix style
Author: YeAnbang
Date: 2024-07-31 14:10:17 +08:00
Committed by: GitHub
Parent: 09c5f72595
Commit: 30f4e31a33
13 changed files with 552 additions and 252 deletions


@@ -7,7 +7,7 @@ from contextlib import nullcontext
 import torch
 from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
+from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
 from coati.trainer import RewardModelTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoTokenizer
@@ -25,8 +25,11 @@ logger = get_dist_logger()
 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
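The new --lora_config path is parsed by LoraConfig.from_file; only the r field is visible in this diff, so the sketch below assumes a peft-style JSON schema (the extra fields and the file path are illustrative, not confirmed by this commit):

    # lora_config.json -- only "r" is confirmed by this diff, the rest is illustrative
    # {
    #     "r": 8,
    #     "lora_alpha": 32,
    #     "lora_dropout": 0.1,
    #     "target_modules": ["q_proj", "v_proj"]
    # }
    from coati.models import LoraConfig

    lora_config = LoraConfig.from_file("lora_config.json")  # hypothetical path
    assert lora_config.r > 0, "r == 0 effectively disables LoRA"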
@@ -58,9 +61,11 @@ def train(args):
         args.pretrain,
     )
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+    if lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
     # ==============================
     # Initialize Booster
     # ==============================
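convert_to_lora_module now takes the whole config object instead of a bare rank, and norm/gate submodules are upcast to fp32 so fp16/bf16 training stays numerically stable. A quick sanity check after conversion could look like the sketch below (lora_sanity_check is a hypothetical helper, not part of the commit):

    import torch
    from torch import nn

    def lora_sanity_check(model: nn.Module) -> None:
        # How much of the model is actually trainable after convert_to_lora_module?
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())
        print(f"trainable params: {trainable}/{total} ({100 * trainable / total:.2f}%)")

        # Spot-check the fp32 upcast this hunk applies to norm/gate submodules.
        for name, module in model.named_modules():
            if "norm" in name or "gate" in name:
                assert all(p.dtype == torch.float32 for p in module.parameters(recurse=False))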
@@ -122,11 +127,9 @@ def train(args):
     booster = Booster(plugin=plugin)
-    if args.grad_checkpoint and args.lora_rank == 0:
-        model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
+    if args.grad_checkpoint:
+        model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
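Gradient checkpointing is no longer skipped when LoRA is active; the new call opts into PyTorch's non-reentrant implementation via use_reentrant=False, which, unlike the reentrant variant, does not require the checkpointed inputs to carry gradients, so it plays well with a mostly frozen LoRA model. Outside this script the same Transformers API is used like so (the checkpoint name is illustrative):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    model.train()  # checkpointing only has an effect in training mode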
@@ -272,16 +275,13 @@ def train(args):
     trainer.fit(
         train_preference_dataloader=train_dataloader,
-        eval_preference_dataloader=None,
+        eval_preference_dataloader=eval_dataloader,
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_dir is not None:
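Merging no longer goes through a global LORA_MANAGER flag; switching the model to eval mode is what triggers the merge. Coati's actual LoRA layer is not shown in this hunk, so the following is only a minimal sketch of the merge-on-eval pattern (class name and internals are illustrative):

    import torch
    import torch.nn as nn

    class SketchLoraLinear(nn.Linear):
        """Illustrative LoRA linear that folds B @ A into the frozen weight in eval()."""

        def __init__(self, in_features, out_features, r=8, alpha=32):
            super().__init__(in_features, out_features)
            self.weight.requires_grad_(False)
            self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
            self.lora_B = nn.Parameter(torch.zeros(out_features, r))
            self.scaling = alpha / r
            self.merged = False

        def train(self, mode: bool = True):
            if not mode and not self.merged:  # .eval() -> fold the adapter into the weight
                self.weight.data += self.scaling * (self.lora_B @ self.lora_A)
                self.merged = True
            elif mode and self.merged:  # back to .train() -> undo the merge
                self.weight.data -= self.scaling * (self.lora_B @ self.lora_A)
                self.merged = False
            return super().train(mode)

        def forward(self, x):
            out = super().forward(x)
            if not self.merged:  # in training the adapter is applied on the fly
                out = out + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)
            return out

With this pattern, calling model.eval() before saving yields merged weights that load into the original architecture without any LoRA-specific code.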
@@ -330,15 +330,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default=None, type=str)
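Dropping --merge_lora_weights also removes an argparse footgun: with type=bool, any non-empty string parses as True, so passing "--merge_lora_weights False" never actually disabled merging. A small demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--merge_lora_weights", type=bool, default=True)

    args = parser.parse_args(["--merge_lora_weights", "False"])
    print(args.merge_lora_weights)  # True -- bool("False") is truthy
    print(bool("False"), bool(""))  # True False: only the empty string is falsy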