Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-01 17:17:05 +00:00)
@@ -0,0 +1,9 @@
+{
+    "r": 128,
+    "embedding_lora_dropout": 0.0,
+    "linear_lora_dropout": 0.1,
+    "lora_alpha": 32,
+    "lora_train_bias": "all",
+    "lora_initialization_method": "PiSSA",
+    "target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
+}
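For orientation only (not part of the diff): the sketch below shows how a config file like the one added above is consumed by the updated training scripts, using the LoraConfig.from_file and convert_to_lora_module calls that appear in the hunks that follow. The file path and model name are placeholders.

    # Sketch only; "lora_config.json" and the model path are placeholder names.
    from coati.models import LoraConfig, convert_to_lora_module
    from transformers import AutoModelForCausalLM

    lora_config = LoraConfig.from_file("lora_config.json")
    model = AutoModelForCausalLM.from_pretrained("path/to/pretrained-model")
    if lora_config.r > 0:
        # "target_modules" in the config selects which submodules receive LoRA adapters
        model = convert_to_lora_module(model, lora_config=lora_config)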
@@ -6,7 +6,7 @@ from contextlib import nullcontext

 import torch
 from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
 from coati.trainer import DPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -115,7 +118,7 @@ def train(args):
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
         model = AutoModelForCausalLM.from_pretrained(args.pretrain)
-    disable_dropout(model)
+
     if not args.disable_reference_model:
         if args.use_flash_attn:
             ref_model = AutoModelForCausalLM.from_pretrained(
@@ -125,15 +128,19 @@ def train(args):
             )
         else:
             ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
         disable_dropout(ref_model)
     else:
         ref_model = None
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+    if args.lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+    disable_dropout(model)
+    disable_dropout(ref_model)

     if args.grad_checkpoint:
-        # Note, for some models, lora may not be compatible with gradient checkpointing
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     # configure tokenizer
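Two patterns recur in the model-setup hunk above and in the other scripts below. First, after LoRA conversion, any submodule whose name contains "norm" or "gate" is cast to float32, presumably to keep normalization and gating math numerically stable while the rest of the model runs in fp16/bf16. Second, gradient checkpointing now passes use_reentrant=False; the non-reentrant implementation generally tolerates models in which most parameters are frozen (as under LoRA), which is likely why the old "lora may not be compatible with gradient checkpointing" caveat disappears here. A hypothetical helper condensing both steps:

    import torch
    from torch import nn

    def prepare_lora_model_for_training(model: nn.Module) -> nn.Module:
        """Hypothetical helper mirroring the pattern in the diff; not coati API."""
        for name, module in model.named_modules():
            # keep normalization/gating layers in full precision for stability
            if "norm" in name or "gate" in name:
                module.to(torch.float32)
        # non-reentrant checkpointing works when most parameters are frozen (non-LoRA)
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
        return model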
@@ -280,11 +287,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_dir is not None:
@@ -343,15 +347,8 @@ if __name__ == "__main__":
         help="Disable the reference model (enabled by default)",
     )
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default=None, type=str)
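Taken together, the hunks above swap the old per-flag LoRA interface (--lora_rank, --lora_train_bias, --merge_lora_weights) for a single --lora_config JSON file, and weight merging is now triggered simply by putting the model in eval mode. A condensed, hypothetical view of the resulting flow in the DPO training script (argument names match the diff; the helper itself is not part of the code):

    from coati.models import LoraConfig, convert_to_lora_module

    def apply_lora_from_args(args, model):
        """Hypothetical condensation of the new control flow."""
        lora_config = LoraConfig.from_file(args.lora_config) if args.lora_config is not None else None
        if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
            raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
        if lora_config is not None:
            model = convert_to_lora_module(model, lora_config=lora_config)
        return model, lora_config

    # after training: per the NOTE in the diff, calling model.eval() merges the LoRA
    # weights before saving, replacing the old LORA_MANAGER.merge_weights flag.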
@@ -6,7 +6,7 @@ from contextlib import nullcontext

 import torch
 from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
 from coati.trainer import KTOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -115,7 +118,7 @@ def train(args):
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
         model = AutoModelForCausalLM.from_pretrained(args.pretrain)
-    disable_dropout(model)
+
     if args.use_flash_attn:
         ref_model = AutoModelForCausalLM.from_pretrained(
             args.pretrain,
@@ -124,13 +127,17 @@ def train(args):
         )
     else:
         ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+    if args.lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
     disable_dropout(ref_model)
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+    disable_dropout(model)

     if args.grad_checkpoint:
         # Note, for some models, lora may not be compatible with gradient checkpointing
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     # configure tokenizer
@@ -299,11 +306,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_dir is not None:
@@ -355,15 +359,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)

     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--auto_weight", default=False, action="store_true")
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
@@ -6,7 +6,7 @@ from contextlib import nullcontext

 import torch
 from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module, disable_dropout
+from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
 from coati.trainer import ORPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -23,8 +23,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -114,13 +117,16 @@ def train(args):
         coordinator.print_on_master(msg="Flash-attention enabled successfully")
     else:
         model = AutoModelForCausalLM.from_pretrained(args.pretrain)
+    if args.lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
     disable_dropout(model)
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

     if args.grad_checkpoint:
         # Note, for some models, lora may not be compatible with gradient checkpointing
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     # configure tokenizer
@@ -262,11 +268,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_dir is not None:
@@ -322,15 +325,8 @@ if __name__ == "__main__":
         help="Disable the reference model (enabled by default)",
     )
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default=None, type=str)
@@ -13,7 +13,7 @@ from coati.dataset import (
     load_tokenized_dataset,
     setup_conversation_template,
 )
-from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
+from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -31,8 +31,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -81,20 +84,26 @@ def train(args):
     ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
     reward_model = RewardModel(args.rm_pretrain)
     critic = Critic(args.rm_pretrain)

+    if args.lora_config is not None:
+        actor = convert_to_lora_module(actor, lora_config=lora_config)
+        critic = convert_to_lora_module(critic, lora_config=lora_config)
+        for name, module in actor.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+        for name, module in critic.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+        lora_manager.able_to_merge = False
+
     # Disable dropout
     disable_dropout(actor)
     disable_dropout(critic)

-    if args.lora_rank > 0:
-        actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
-        critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
-    if args.grad_checkpoint and args.lora_rank == 0:
-        actor.gradient_checkpointing_enable()
-        critic.model.gradient_checkpointing_enable()
+    if args.grad_checkpoint:
+        actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
+        critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
+
     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@@ -421,11 +430,9 @@ def train(args):
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
+        lora_manager.able_to_merge = True
         actor.eval()
         critic.eval()
     # save model checkpoint after fitting on only rank0
@@ -484,11 +491,9 @@ if __name__ == "__main__":
     parser.add_argument("--train_batch_size", type=int, default=16)
     parser.add_argument("--experience_batch_size", type=int, default=16)
     parser.add_argument("--ptx_batch_size", type=int, default=4)
-    parser.add_argument("--lora_train_bias", type=str, default="none")
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--accumulation_steps", type=int, default=8)
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=9e-6)
     parser.add_argument("--critic_lr", type=float, default=9e-6)
     parser.add_argument("--kl_coef", type=float, default=0.1)
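The PPO script additionally guards merging with a module-level flag, lora_manager.able_to_merge: it is set to False while the actor and critic train and flipped back to True just before they are put into eval mode, which (per the in-code NOTE) is when LoRA weights get merged for saving. A hypothetical condensation:

    from coati.models import lora_manager

    def finish_ppo_lora(actor, critic):
        """Hypothetical helper; actor and critic are the trained modules from the script."""
        # during setup the diff sets lora_manager.able_to_merge = False to block merging
        lora_manager.able_to_merge = True  # lift the guard once training is done
        actor.eval()   # per the NOTE above, eval() is what merges the LoRA weights
        critic.eval()
        return actor, critic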
@@ -7,7 +7,7 @@ from contextlib import nullcontext

 import torch
 from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
+from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
 from coati.trainer import RewardModelTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoTokenizer
@@ -25,8 +25,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -58,9 +61,11 @@ def train(args):
         args.pretrain,
     )

-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-
+    if lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
     # ==============================
     # Initialize Booster
     # ==============================
@@ -122,11 +127,9 @@ def train(args):

     booster = Booster(plugin=plugin)

-    if args.grad_checkpoint and args.lora_rank == 0:
-        model.model.gradient_checkpointing_enable()  # TODO: support gradient checkpoint for the last linear layer
+    if args.grad_checkpoint:
+        model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")

     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
@@ -272,16 +275,13 @@ def train(args):

     trainer.fit(
         train_preference_dataloader=train_dataloader,
-        eval_preference_dataloader=None,
+        eval_preference_dataloader=eval_dataloader,
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_dir is not None:
@@ -330,15 +330,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
     parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--accumulation_steps", type=int, default=8)
     parser.add_argument("--log_dir", default=None, type=str)
@@ -7,7 +7,7 @@ from contextlib import nullcontext

 import torch
 from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
-from coati.models import convert_to_lora_module
+from coati.models import LoraConfig, convert_to_lora_module
 from coati.trainer import SFTTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -24,8 +24,11 @@ logger = get_dist_logger()


 def train(args):
+    lora_config = None
+    if args.lora_config is not None:
+        lora_config = LoraConfig.from_file(args.lora_config)
     # check lora compatibility
-    if "gemini" in args.plugin and args.lora_rank > 0:
+    if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
         raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
     if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
         raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
@@ -53,8 +56,12 @@ def train(args):
         torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
         trust_remote_code=True,
     )
-    if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

+    if lora_config is not None:
+        model = convert_to_lora_module(model, lora_config=lora_config)
+        for name, module in model.named_modules():
+            if "norm" in name or "gate" in name:
+                module = module.to(torch.float32)
+
     if args.plugin == "ddp":
         """
@@ -114,6 +121,15 @@ def train(args):

     booster = Booster(plugin=plugin)

+    # configure optimizer
+    optim = HybridAdam(
+        model_params=model.parameters(),
+        lr=args.lr,
+        betas=(0.9, 0.95),
+        weight_decay=args.weight_decay,
+        adamw_mode=True,
+    )
+
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
@@ -124,7 +140,7 @@ def train(args):

     if args.grad_checkpoint:
         # Note, for some models, lora may not be compatible with gradient checkpointing
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     # configure tokenizer
@@ -149,15 +165,6 @@ def train(args):
     coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
     coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")

-    # configure optimizer
-    optim = HybridAdam(
-        model_params=model.parameters(),
-        lr=args.lr,
-        betas=(0.9, 0.95),
-        weight_decay=args.weight_decay,
-        adamw_mode=True,
-    )
-
     # configure dataset
     coordinator.print_on_master(
         f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
@@ -217,6 +224,7 @@ def train(args):
         lr_scheduler=lr_scheduler,
         dataloader=train_dataloader,
     )

+    torch.set_default_dtype(torch.float)

     coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -277,11 +285,8 @@ def train(args):
         use_wandb=args.use_wandb,
     )

-    if args.lora_rank > 0 and args.merge_lora_weights:
-        from coati.models.lora import LORA_MANAGER
-
+    if lora_config is not None and lora_config.r > 0:
         # NOTE: set model to eval to merge LoRA weights
-        LORA_MANAGER.merge_weights = True
         model.eval()
     # save model checkpoint after fitting on only rank0
     if args.save_path is not None:
@@ -328,15 +333,8 @@ if __name__ == "__main__":
     parser.add_argument("--batch_size", type=int, default=4)
     parser.add_argument("--max_len", type=int, default=512)
     parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
-    parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
-    parser.add_argument(
-        "--lora_train_bias",
-        type=str,
-        default="none",
-        help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
-    )
+    parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
     parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
-    parser.add_argument("--merge_lora_weights", type=bool, default=True)
     parser.add_argument("--lr", type=float, default=5e-6)
     parser.add_argument("--config_file", type=str, default=None, help="Config file")
     parser.add_argument("--accumulation_steps", type=int, default=8)
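Beyond the LoRA changes, the SFT script moves the HybridAdam construction earlier, right after the Booster is created and before boosting, and adds torch.set_default_dtype(torch.float) once boosting is done. The sketch below assumes ColossalAI's usual Booster.boost signature and uses placeholder arguments; it illustrates the new ordering and is not the script itself:

    import torch
    from colossalai.booster import Booster
    from colossalai.nn.optimizer import HybridAdam

    def boost_for_sft(model, plugin, train_dataloader, lr_scheduler, lr, weight_decay):
        """Hypothetical helper mirroring the reordered setup in the SFT diff."""
        booster = Booster(plugin=plugin)
        optim = HybridAdam(
            model_params=model.parameters(),
            lr=lr,
            betas=(0.9, 0.95),
            weight_decay=weight_decay,
            adamw_mode=True,
        )
        model, optim, _, train_dataloader, lr_scheduler = booster.boost(
            model=model, optimizer=optim, dataloader=train_dataloader, lr_scheduler=lr_scheduler
        )
        torch.set_default_dtype(torch.float)  # the diff adds this right after boosting
        return booster, model, optim, train_dataloader, lr_scheduler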
@@ -21,16 +21,16 @@ PARENT_LOG_DIR="" # Path to a folder to save training config logs
 PRETRAINED_MODEL_PATH="" # huggingface or local model path
 PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
 declare -a dataset=(
-    /Your/SFT/Data/arrow/part-00000
-    /Your/SFT/Data/arrow/part-00001
-    /Your/SFT/Data/arrow/part-00002
-    /Your/SFT/Data/arrow/part-00003
-    /Your/SFT/Data/arrow/part-00004
-    /Your/SFT/Data/arrow/part-00005
-    /Your/SFT/Data/arrow/part-00006
-    /Your/SFT/Data/arrow/part-00007
-    /Your/SFT/Data/arrow/part-00008
-    /Your/SFT/Data/arrow/part-00009
+    YOUR/SFT/DATA/DIR/arrow/part-00000
+    YOUR/SFT/DATA/DIR/arrow/part-00001
+    YOUR/SFT/DATA/DIR/arrow/part-00002
+    YOUR/SFT/DATA/DIR/arrow/part-00003
+    YOUR/SFT/DATA/DIR/arrow/part-00004
+    YOUR/SFT/DATA/DIR/arrow/part-00005
+    YOUR/SFT/DATA/DIR/arrow/part-00006
+    YOUR/SFT/DATA/DIR/arrow/part-00007
+    YOUR/SFT/DATA/DIR/arrow/part-00008
+    YOUR/SFT/DATA/DIR/arrow/part-00009
 )

 TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -47,15 +47,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
     --tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
     --save_interval 2000 \
     --dataset ${dataset[@]} \
-    --save_path $SAVE_DIR \
-    --config_file $CONFIG_FILE \
-    --log_dir $LOG_DIR \
-    --lora_rank 0 \
     --plugin zero2 \
     --batch_size 8 \
     --max_epochs 1 \
-    --accumulation_steps 2 \
+    --accumulation_steps 1 \
     --lr 5e-5 \
     --max_len 4096 \
+    --use_flash_attn \
     --grad_checkpoint \
-    --use_flash_attn
+    --save_path $SAVE_DIR \
+    --config_file $CONFIG_FILE \
+    --log_dir $LOG_DIR \