upgrade ppo dpo rm script

Author: YeAnbang
Date: 2024-05-28 03:04:39 +00:00
Parent: 7a7e86987d
Commit: 929e1e3da4

15 changed files with 169 additions and 139 deletions


@@ -18,6 +18,7 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
 from coati.trainer import PPOTrainer
 from coati.utils import load_checkpoint
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from colossalai.shardformer.policies.auto_policy import get_autopolicy

 import colossalai
 from colossalai.booster import Booster
@@ -86,32 +87,6 @@ def train(args):
     disable_dropout(actor)
     disable_dropout(critic)
-    if args.tp > 1:
-        if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
-            raise ValueError("Reward model and critic model must have the same architecture")
-        if reward_model.model.config.architectures[0] == "BloomForCausalLM":
-            from colossalai.shardformer.policies.bloom import BloomPolicy
-
-            booster_policy = BloomPolicy()
-        elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
-            from colossalai.shardformer.policies.llama import LlamaPolicy
-
-            booster_policy = LlamaPolicy()
-        elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
-            from colossalai.shardformer.policies.gpt2 import GPT2Policy
-
-            booster_policy = GPT2Policy()
-        elif reward_model.model.config.architectures[0] == "ChatGLMModel":
-            from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
-
-            booster_policy = ChatGLMPolicy()
-        elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
-            from colossalai.shardformer.policies.opt import OPTPolicy
-
-            booster_policy = OPTPolicy()
-        else:
-            raise ValueError("Unknown model architecture for policy")
     if args.lora_rank > 0:
         actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
         critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
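The hand-written architecture-to-policy ladder above is deleted in favor of shardformer's policy registry, imported in the first hunk and applied in the HybridParallelPlugin hunk below. A minimal sketch of the lookup; the checkpoint name is illustrative, and any architecture shardformer already supports works the same way:

```python
# Minimal sketch of the auto-detection that replaces the ladder above.
# get_autopolicy maps a transformers module to its registered shardformer
# policy, so newly supported architectures need no per-script edits.
from transformers import AutoModelForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

model = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative checkpoint
policy = get_autopolicy(model)  # resolves to the registered GPT-2 policy
```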
@@ -186,7 +161,7 @@ def train(args):
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        use_tp=args.tp > 1,
+        tp_size=args.tp,
     )

     if len(args.ptx_dataset) > 0:
@@ -198,7 +173,7 @@ def train(args):
             shuffle=True,
             drop_last=True,
             collate_fn=data_collator,
-            use_tp=args.tp > 1,
+            tp_size=args.tp,
         )
     else:
         train_pretrain_dataloader = None
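Both dataloader hunks make the same swap: the boolean `use_tp` becomes an explicit `tp_size`, so the loader no longer has to rediscover the tensor-parallel degree. A hedged sketch of why the degree matters; `dp_world_size` is a hypothetical helper for illustration, not the repo's API:

```python
# Hypothetical helper showing what tp_size buys the dataloader: ranks in
# the same tensor-parallel group must see identical batches, so the number
# of distinct data-parallel shards is world_size // tp_size.
import torch.distributed as dist

def dp_world_size(tp_size: int = 1) -> int:
    # Assumes the default process group is already initialized.
    return dist.get_world_size() // tp_size
```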
@@ -237,6 +212,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -244,6 +220,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
+            enable_flash_attention=args.use_flash_attn
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
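Both Gemini variants now forward the flash-attention switch to the plugin. A minimal sketch of the resulting construction, trimmed to the kwargs this diff touches; values are placeholders for the argparse inputs, and a launched distributed context is assumed:

```python
# Sketch: GeminiPlugin with the newly threaded flash-attention switch.
# Only the kwargs visible in this diff are shown; the script passes more.
# Assumes colossalai.launch_from_torch() has already run.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    initial_scale=2**16,          # loss-scale seed, as in the hunk above
    max_norm=1.0,                 # placeholder for args.grad_clip
    enable_flash_attention=True,  # placeholder for args.use_flash_attn
)
booster = Booster(plugin=plugin)
```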
@@ -270,11 +247,17 @@ def train(args):
         )
         custom_plugin = HybridParallelPlugin(
             tp_size=args.tp,
-            pp_size=1,
-            zero_stage=0,
+            pp_size=args.pp,
+            sp_size=args.sp,
+            sequence_parallelism_mode=args.sp_mode,
+            zero_stage=args.zero_stage,
+            enable_flash_attention=args.use_flash_attn,
+            enable_sequence_parallelism=True if args.sp > 1 else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            custom_policy=booster_policy,
+            custom_policy=get_autopolicy(reward_model.model),
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -474,6 +457,11 @@ if __name__ == "__main__":
     parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--tp", type=int, default=1)
+    parser.add_argument("--pp", type=int, default=1)
+    parser.add_argument("--sp", type=int, default=1)
+    parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
+    parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
+    parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--rm_pretrain", type=str, default=None)
     parser.add_argument("--checkpoint_path", type=str, default=None)