mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
upgrade ppo dpo rm script
This commit is contained in:
@@ -15,7 +15,8 @@ from coati.dataset import (
|
||||
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
|
||||
from coati.trainer import RewardModelTrainer
|
||||
from coati.utils import load_checkpoint
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from colossalai.shardformer.policies.auto_policy import get_autopolicy
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
@@ -56,31 +57,10 @@ def train(args):
|
||||
)
|
||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||
else:
|
||||
model = RewardModel(args.pretrain)
|
||||
|
||||
if args.tp > 1:
|
||||
if model.model.config.architectures[0] == "BloomForCausalLM":
|
||||
from colossalai.shardformer.policies.bloom import BloomPolicy
|
||||
|
||||
booster_policy = BloomPolicy()
|
||||
elif model.model.config.architectures[0] == "LlamaForCausalLM":
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
booster_policy = LlamaPolicy()
|
||||
elif model.model.config.architectures[0] == "GPT2LMHeadModel":
|
||||
from colossalai.shardformer.policies.gpt2 import GPT2Policy
|
||||
|
||||
booster_policy = GPT2Policy()
|
||||
elif model.model.config.architectures[0] == "ChatGLMModel":
|
||||
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
|
||||
|
||||
booster_policy = ChatGLMPolicy()
|
||||
elif model.model.config.architectures[0] == "OPTForCausalLM":
|
||||
from colossalai.shardformer.policies.opt import OPTPolicy
|
||||
|
||||
booster_policy = OPTPolicy()
|
||||
else:
|
||||
raise ValueError("Unknown model architecture for policy")
|
||||
model_config = AutoConfig.from_pretrained(args.pretrain)
|
||||
model = RewardModel(
|
||||
args.pretrain,
|
||||
)
|
||||
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
@@ -100,6 +80,7 @@ def train(args):
|
||||
placement_policy="static",
|
||||
initial_scale=2**16,
|
||||
max_norm=args.grad_clip,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_gradient_accumulation=True,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
@@ -107,6 +88,7 @@ def train(args):
|
||||
precision=args.mixed_precision,
|
||||
placement_policy="auto",
|
||||
initial_scale=2**16,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
max_norm=args.grad_clip,
|
||||
)
|
||||
elif args.plugin == "zero2":
|
||||
@@ -127,11 +109,17 @@ def train(args):
|
||||
elif args.plugin == "3d":
|
||||
plugin = HybridParallelPlugin(
|
||||
tp_size=args.tp,
|
||||
pp_size=1,
|
||||
zero_stage=0,
|
||||
pp_size=args.pp,
|
||||
sp_size=args.sp,
|
||||
sequence_parallelism_mode=args.sp_mode,
|
||||
zero_stage=args.zero_stage,
|
||||
enable_flash_attention=args.use_flash_attn,
|
||||
enable_sequence_parallelism=True if args.sp > 1 else False,
|
||||
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
|
||||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
custom_policy=booster_policy,
|
||||
custom_policy=get_autopolicy(model.model)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
@@ -189,7 +177,7 @@ def train(args):
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
collate_fn=data_collator,
|
||||
use_tp=args.tp > 1,
|
||||
tp_size=args.tp,
|
||||
)
|
||||
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
|
||||
@@ -307,6 +295,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
||||
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
|
||||
parser.add_argument("--pretrain", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_dir", type=str, default=None)
|
||||
parser.add_argument("--dataset", nargs="+", default=[])
|
||||
|
Reference in New Issue
Block a user