@@ -13,7 +13,7 @@ from coati.dataset import (
     load_tokenized_dataset,
     setup_conversation_template,
 )
-from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
+from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -31,8 +31,11 @@ logger = get_dist_logger()
 
 
 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
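The hunk above replaces the old `--lora_rank` check with a config object loaded from a file. Below is a minimal sketch of that flow, assuming a JSON config file: only `LoraConfig.from_file` and the `r` attribute appear in the diff, so the stand-in class and field set here are illustrative assumptions, not coati's actual implementation.

import json
from dataclasses import dataclass

# Write an example config so the sketch runs end to end; {"r": 8} is the only
# field the diff relies on (via lora_config.r).
with open("lora_example.json", "w") as f:
    json.dump({"r": 8}, f)


@dataclass
class LoraConfigSketch:  # hypothetical stand-in for coati.models.LoraConfig
    r: int = 0  # LoRA rank; 0 disables LoRA

    @classmethod
    def from_file(cls, path: str) -> "LoraConfigSketch":
        with open(path) as f:
            return cls(**json.load(f))


def check_plugin_compat(plugin: str, lora_config) -> None:
    # Same guard the hunk adds: GeminiPlugin cannot train LoRA modules.
    if "gemini" in plugin and lora_config is not None and lora_config.r > 0:
        raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")


cfg = LoraConfigSketch.from_file("lora_example.json")
check_plugin_compat("zero2", cfg)  # passes; plugin="gemini" would raise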
@@ -81,20 +84,26 @@ def train(args):
     ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
     reward_model = RewardModel(args.rm_pretrain)
     critic = Critic(args.rm_pretrain)
 
+    if args.lora_config is not None:
+        actor = convert_to_lora_module(actor, lora_config=lora_config)
+        critic = convert_to_lora_module(critic, lora_config=lora_config)
+        for name, module in actor.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+        for name, module in critic.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+        lora_manager.able_to_merge = False
+
     # Disable dropout
     disable_dropout(actor)
     disable_dropout(critic)
 
-    if args.lora_rank > 0:
-        actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
-        critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
-    if args.grad_checkpoint and args.lora_rank == 0:
-        actor.gradient_checkpointing_enable()
-        critic.model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
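Two details in this hunk are worth illustrating. After conversion to LoRA modules, any submodule whose name contains "norm" or "gate" is kept in float32, so layer norms and gating stay numerically stable while the rest of the half-precision model trains only its adapters; the standalone sketch below reproduces that upcast on a toy model (the module names are placeholders). Gradient checkpointing is now enabled whenever --grad_checkpoint is set, passing use_reentrant=False through Hugging Face's gradient_checkpointing_kwargs; the non-reentrant torch.utils.checkpoint mode is generally the one that composes with adapter training where the base weights are frozen, which is presumably why LoRA no longer disables checkpointing.

import torch
import torch.nn as nn

# Toy model standing in for the actor/critic; pretend it was loaded in fp16.
model = nn.ModuleDict(
    {
        "proj": nn.Linear(16, 16),
        "norm": nn.LayerNorm(16),
        "gate_proj": nn.Linear(16, 16),
    }
).half()

# Same selective upcast the hunk applies to actor and critic.
for name, module in model.named_modules():
    if "norm" in name or "gate" in name:
        # nn.Module.to() converts parameters in place and returns self, so the
        # parent model sees the fp32 weights even though only the loop variable
        # is rebound here.
        module = module.to(torch.float32)

print(model["proj"].weight.dtype)       # torch.float16
print(model["norm"].weight.dtype)       # torch.float32
print(model["gate_proj"].weight.dtype)  # torch.float32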
@@ -421,11 +430,9 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
+        lora_manager.able_to_merge = True
         actor.eval()
         critic.eval()
     # save model checkpoint after fitting on only rank0
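Here the global LORA_MANAGER.merge_weights flag and the --merge_lora_weights argument are replaced by lora_manager.able_to_merge, which an earlier hunk sets to False right after conversion and which is flipped to True just before actor.eval()/critic.eval(), so the adapter deltas are folded into the base weights only at checkpoint time. The sketch below shows the general merge-on-eval pattern such a flag gates; LoraLinearSketch is a hypothetical stand-in, not coati's LoRA layer.

import torch
import torch.nn as nn
from types import SimpleNamespace

# Stand-in for coati.models.lora_manager: a module-level switch.
lora_manager = SimpleNamespace(able_to_merge=False)


class LoraLinearSketch(nn.Linear):
    def __init__(self, in_features: int, out_features: int, r: int = 8):
        super().__init__(in_features, out_features)
        # Usual LoRA init: A random, B zero, so the delta starts as a no-op.
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))
        self.merged = False

    def eval(self):
        # While able_to_merge is False (i.e. during training), eval() leaves the
        # base weight untouched; once the flag is flipped, the low-rank delta is
        # folded in so the checkpoint loads without any LoRA code.
        if lora_manager.able_to_merge and not self.merged:
            self.weight.data += self.lora_B @ self.lora_A
            self.merged = True
        return super().eval()


layer = LoraLinearSketch(16, 16)
layer.eval()                       # during training: flag is False, no merge
lora_manager.able_to_merge = True
layer.eval()                       # after training: delta merged before saving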
@@ -484,11 +491,9 @@ if __name__ == "__main__":
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--experience_batch_size", type=int, default=16)
     parser.add_argument("--ptx_batch_size", type=int, default=4)
-    parser.add_argument("--lora_train_bias", type=str, default="none")
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--accumulation_steps", type=int, default=8)
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--critic_lr", type=float, default=9e-6)
     parser.add_argument("--kl_coef", type=float, default=0.1)
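On the CLI, --lora_train_bias, --lora_rank, and --merge_lora_weights disappear in favour of a single --lora_config path (which also drops the type=bool pitfall of the old --merge_lora_weights flag, since argparse's bool() treats any non-empty string, including "False", as True). Below is a hedged before/after sketch of an invocation: the script name and every non-LoRA argument are placeholders, the JSON format is an assumption, and "r" is the only config field the script is shown to read.

import json

old_cmd = [
    "python", "train_ppo.py",          # placeholder script name
    "--lora_rank", "8",
    "--lora_train_bias", "none",
]

# After this diff the same settings live in a config file passed by path.
with open("lora_config.json", "w") as f:
    json.dump({"r": 8}, f)

new_cmd = [
    "python", "train_ppo.py",          # placeholder script name
    "--lora_config", "lora_config.json",
]

print(" ".join(old_cmd))
print(" ".join(new_cmd))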