Mirror of https://github.com/hpcaitech/ColossalAI.git
@@ -7,7 +7,7 @@ from contextlib import nullcontext
 import torch
 from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module
+from coati.models import LoraConfig, convert_to_lora_module
 from coati.trainer import SFTTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -24,8 +24,11 @@ logger = get_dist_logger()
 
 
 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
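LoraConfig.from_file replaces the old --lora_rank / --lora_train_bias flags: the whole adapter setup now comes from a config file. The diff only shows that the config exposes an r (rank) field; below is a minimal sketch of a loader with that shape, assuming a JSON file, with every field name other than r being an illustrative assumption rather than coati's actual schema.

import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class LoraConfigSketch:
    # hypothetical stand-in for coati.models.LoraConfig; only `r` is taken from the diff
    r: int = 0
    lora_alpha: int = 32
    target_modules: Optional[List[str]] = None

    @classmethod
    def from_file(cls, path: str) -> "LoraConfigSketch":
        # e.g. {"r": 8, "lora_alpha": 32, "target_modules": ["q_proj", "v_proj"]}
        with open(path, "r") as f:
            return cls(**json.load(f))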
@@ -53,8 +56,12 @@ def train(args):
         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
         trust_remote_code=True,
     )
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+    if lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
 
     if args.plugin == "ddp":
         """
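After wrapping the model with LoRA, the new code keeps every module whose name contains "norm" or "gate" in float32 while the rest of the model stays in bf16/fp16; nn.Module.to casts parameters and buffers in place, so rebinding the loop variable is harmless. A self-contained sketch of the same pattern on a toy model (the module names below are chosen only so the name filter matches):

import torch
import torch.nn as nn
from collections import OrderedDict

# toy stand-in for the causal LM, loaded in half precision
model = nn.Sequential(OrderedDict([
    ("proj", nn.Linear(8, 8)),
    ("norm", nn.LayerNorm(8)),
])).to(torch.bfloat16)

# keep normalization (and gating) layers in full precision for numerical stability
for name, module in model.named_modules():
    if "norm" in name or "gate" in name:
        module.to(torch.float32)  # in-place cast of this submodule's params/buffers

print({n: p.dtype for n, p in model.named_parameters()})
# proj.* remain torch.bfloat16, norm.* are torch.float32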
@@ -114,6 +121,15 @@ def train(args):
 
     booster = Booster(plugin=plugin)
 
+    # configure optimizer
+    optim = HybridAdam(
+        model_params=model.parameters(),
+        lr=args.lr,
+        betas=(0.9, 0.95),
+        weight_decay=args.weight_decay,
+        adamw_mode=True,
+    )
+
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
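This hunk moves the HybridAdam construction up so the optimizer already exists before booster.boost is called (the old construction site is deleted in a later hunk). The intended ordering, build the optimizer from the raw model's parameters and hand both to the booster so the plugin can wrap or shard them together, is sketched below; the plugin choice and hyperparameters are placeholders, and distributed launch/setup is assumed to have happened already.

from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam


def boost_for_training(model, train_dataloader, lr_scheduler):
    # sketch only: the real script picks the plugin from args.plugin
    booster = Booster(plugin=LowLevelZeroPlugin())
    # the optimizer is created first, from the unwrapped model's parameters ...
    optim = HybridAdam(
        model_params=model.parameters(),
        lr=5e-6,
        betas=(0.9, 0.95),
        weight_decay=0.1,
        adamw_mode=True,  # decoupled (AdamW-style) weight decay
    )
    # ... then booster.boost wraps model and optimizer together so the plugin can manage optimizer states
    model, optim, _, train_dataloader, lr_scheduler = booster.boost(
        model=model, optimizer=optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
    )
    return model, optim, train_dataloader, lr_scheduler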
@@ -124,7 +140,7 @@ def train(args):
 
     if args.grad_checkpoint:
         # Note, for some models, lora may not be compatible with gradient checkpointing
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
 
     # configure tokenizer
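transformers forwards gradient_checkpointing_kwargs to torch.utils.checkpoint, so this change switches to the non-reentrant implementation. That matters for LoRA: with the reentrant variant, a checkpointed segment whose inputs do not require grad (typical when embeddings and most weights are frozen) lets no gradient flow back to the parameters inside it, which is likely what the "lora may not be compatible" note refers to. A minimal illustration with torch.utils.checkpoint directly:

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)   # trainable, like a LoRA adapter inside an otherwise frozen block
x = torch.randn(2, 16)            # input does not require grad (e.g. output of a frozen embedding)

out = checkpoint(layer, x, use_reentrant=False)  # non-reentrant: recomputes in backward, grads still flow
out.sum().backward()
print(layer.weight.grad is not None)  # True
# with use_reentrant=True this call warns "None of the inputs have requires_grad=True. Gradients will be
# None" and no gradient reaches layer.weight through the checkpointed segment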
@@ -149,15 +165,6 @@ def train(args):
     coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
     coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
 
-    # configure optimizer
-    optim = HybridAdam(
-        model_params=model.parameters(),
-        lr=args.lr,
-        betas=(0.9, 0.95),
-        weight_decay=args.weight_decay,
-        adamw_mode=True,
-    )
-
     # configure dataset
     coordinator.print_on_master(
         f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -217,6 +224,7 @@ def train(args):
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )
 
+    torch.set_default_dtype(torch.float)
 
     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
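The added line restores the process-wide default dtype to float32 right after booster.boost. This only makes sense if the script switches the default to fp16/bf16 around model and booster initialization, which is the usual pattern in these ColossalChat scripts; without the reset, every tensor created afterwards (metrics, buffers, scheduler state) would silently default to half precision. The effect of the call itself:

import torch

torch.set_default_dtype(torch.float16)  # e.g. while the model is materialized for the booster
print(torch.empty(2).dtype)             # torch.float16

torch.set_default_dtype(torch.float)    # restore the default after boosting
print(torch.empty(2).dtype)             # torch.float32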
@@ -277,11 +285,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )
 
-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_path is not None:
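The global LORA_MANAGER switch is gone; per the retained NOTE, the new LoRA modules are expected to fold the adapter into the base weight when the model is put into eval mode. The arithmetic behind such a merge, using the conventional peft-style scaling lora_alpha / r (coati's exact scaling is an assumption here), is:

import torch

d_out, d_in, r, lora_alpha = 16, 16, 4, 32
W = torch.randn(d_out, d_in)      # frozen base weight
A = torch.randn(r, d_in) * 0.01   # trained LoRA factors
B = torch.randn(d_out, r) * 0.01
scaling = lora_alpha / r

# merging folds the low-rank update into the dense weight, so inference needs one matmul instead of two
W_merged = W + scaling * (B @ A)

x = torch.randn(2, d_in)
unmerged = x @ W.T + scaling * ((x @ A.T) @ B.T)  # base path + adapter path, as during training
merged = x @ W_merged.T                           # single dense matmul after the merge
print(torch.allclose(unmerged, merged, atol=1e-5))  # True: merging does not change outputs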
@@ -328,15 +333,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--config_file", type=str, default=None, help="Config file")
     parser.add_argument("--accumulation_steps", type=int, default=8)