Upgrade PPO, DPO, and RM scripts

YeAnbang
2024-05-28 03:04:39 +00:00
parent 7a7e86987d
commit 929e1e3da4
15 changed files with 169 additions and 139 deletions

View File

@@ -56,6 +56,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -63,6 +64,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -82,9 +84,15 @@ def train(args):
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
@@ -172,7 +180,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -290,6 +298,11 @@ if __name__ == "__main__":
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
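
The '3d' branch in this script now builds HybridParallelPlugin directly from the new command-line flags. Below is a minimal sketch of that wiring; the keyword arguments and the flags --tp, --pp, --sp, --sp_mode, --zero_stage, --zero_cpu_offload and --grad_clip are taken from the hunks above, while the standalone parser and the --use_flash_attn and --mixed_precision flags (and their defaults) are assumed for illustration.

import argparse

from colossalai.booster.plugin import HybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--zero_stage", type=int, default=0, choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")  # assumed flag
parser.add_argument("--grad_clip", type=float, default=1.0)
parser.add_argument("--mixed_precision", type=str, default="bf16")           # assumed flag
args = parser.parse_args()

# Parallel sizes and the ZeRO stage now come from the CLI; sequence parallelism is
# switched on only when more than one SP rank is requested, and CPU offload only
# applies together with ZeRO stage >= 1, mirroring the '3d' branch above.
plugin = HybridParallelPlugin(
    tp_size=args.tp,
    pp_size=args.pp,
    sp_size=args.sp,
    sequence_parallelism_mode=args.sp_mode,
    zero_stage=args.zero_stage,
    enable_flash_attention=args.use_flash_attn,
    enable_sequence_parallelism=args.sp > 1,
    cpu_offload=args.zero_stage >= 1 and args.zero_cpu_offload,
    parallel_output=False,
    max_norm=args.grad_clip,
    precision=args.mixed_precision,
)

The plain boolean expressions (args.sp > 1, and args.zero_stage >= 1 and args.zero_cpu_offload) are equivalent to the "True if ... else False" ternaries used in the diff.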

View File

@@ -18,6 +18,7 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dr
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.shardformer.policies.auto_policy import get_autopolicy
import colossalai
from colossalai.booster import Booster
@@ -86,32 +87,6 @@ def train(args):
disable_dropout(actor)
disable_dropout(critic)
if args.tp > 1:
if reward_model.model.config.architectures[0] != critic.model.config.architectures[0]:
raise ValueError("Reward model and critic model must have the same architecture")
if reward_model.model.config.architectures[0] == "BloomForCausalLM":
from colossalai.shardformer.policies.bloom import BloomPolicy
booster_policy = BloomPolicy()
elif reward_model.model.config.architectures[0] == "LlamaForCausalLM":
from colossalai.shardformer.policies.llama import LlamaPolicy
booster_policy = LlamaPolicy()
elif reward_model.model.config.architectures[0] == "GPT2LMHeadModel":
from colossalai.shardformer.policies.gpt2 import GPT2Policy
booster_policy = GPT2Policy()
elif reward_model.model.config.architectures[0] == "ChatGLMModel":
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
booster_policy = ChatGLMPolicy()
elif reward_model.model.config.architectures[0] == "OPTForCausalLM":
from colossalai.shardformer.policies.opt import OPTPolicy
booster_policy = OPTPolicy()
else:
raise ValueError("Unknown model architecture for policy")
if args.lora_rank > 0:
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -186,7 +161,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
if len(args.ptx_dataset) > 0:
@@ -198,7 +173,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
else:
train_pretrain_dataloader = None
@@ -237,6 +212,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -244,6 +220,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -270,11 +247,17 @@ def train(args):
)
custom_plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=booster_policy,
custom_policy=get_autopolicy(reward_model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -474,6 +457,11 @@ if __name__ == "__main__":
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--rm_pretrain", type=str, default=None)
parser.add_argument("--checkpoint_path", type=str, default=None)
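
The PPO script drops the hand-written mapping from reward_model.model.config.architectures[0] to a Shardformer policy class and instead lets Shardformer resolve the policy from the model itself. A minimal sketch of that change follows; resolve_booster_policy is a hypothetical helper name, and only get_autopolicy and its NotImplementedError behaviour are taken from this commit.

from colossalai.shardformer.policies.auto_policy import get_autopolicy


def resolve_booster_policy(hf_model):
    # The removed code matched hf_model.config.architectures[0] against a fixed
    # list (Bloom, Llama, GPT2, ChatGLM, OPT) and instantiated the policy by hand,
    # raising ValueError for anything else. get_autopolicy performs the same lookup
    # for any architecture Shardformer knows about and raises NotImplementedError
    # when no policy is registered.
    return get_autopolicy(hf_model)


# Used as in the '3d' branch above, e.g.:
#   custom_plugin = HybridParallelPlugin(..., custom_policy=resolve_booster_policy(reward_model.model))

Note that the removed block also asserted that the reward model and the critic share an architecture; that check disappears together with the manual mapping.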

View File

@@ -15,7 +15,8 @@ from coati.dataset import (
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy
import colossalai
from colossalai.booster import Booster
@@ -56,31 +57,10 @@ def train(args):
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model = RewardModel(args.pretrain)
if args.tp > 1:
if model.model.config.architectures[0] == "BloomForCausalLM":
from colossalai.shardformer.policies.bloom import BloomPolicy
booster_policy = BloomPolicy()
elif model.model.config.architectures[0] == "LlamaForCausalLM":
from colossalai.shardformer.policies.llama import LlamaPolicy
booster_policy = LlamaPolicy()
elif model.model.config.architectures[0] == "GPT2LMHeadModel":
from colossalai.shardformer.policies.gpt2 import GPT2Policy
booster_policy = GPT2Policy()
elif model.model.config.architectures[0] == "ChatGLMModel":
from colossalai.shardformer.policies.chatglm2 import ChatGLMPolicy
booster_policy = ChatGLMPolicy()
elif model.model.config.architectures[0] == "OPTForCausalLM":
from colossalai.shardformer.policies.opt import OPTPolicy
booster_policy = OPTPolicy()
else:
raise ValueError("Unknown model architecture for policy")
model_config = AutoConfig.from_pretrained(args.pretrain)
model = RewardModel(
args.pretrain,
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -100,6 +80,7 @@ def train(args):
placement_policy="static",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn,
enable_gradient_accumulation=True,
)
elif args.plugin == "gemini_auto":
@@ -107,6 +88,7 @@ def train(args):
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
enable_flash_attention=args.use_flash_attn,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
@@ -127,11 +109,17 @@ def train(args):
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=booster_policy,
custom_policy=get_autopolicy(model.model)
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -189,7 +177,7 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
use_tp=args.tp > 1,
tp_size=args.tp,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -307,6 +295,11 @@ if __name__ == "__main__":
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
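
All three training scripts also change their data-loader call sites from use_tp=args.tp > 1 to tp_size=args.tp, so the loader is told the actual tensor-parallel group size rather than a yes/no flag. The helper below is a hypothetical stand-in for the dataset utility these scripts call (its real name and signature are not shown in the hunks); it only illustrates why a size is more useful than a boolean: ranks inside one TP group must consume the same micro-batch, so the sampler should be built over world_size // tp_size data-parallel replicas.

import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def setup_distributed_dataloader(dataset, batch_size, shuffle, drop_last, collate_fn, tp_size=1):
    # Hypothetical sketch: ranks within a TP group hold shards of the same model
    # replica and must see identical batches, so only world_size // tp_size
    # independent data-parallel consumers exist.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_size = world_size // tp_size
    dp_rank = rank // tp_size  # assumes TP ranks are grouped contiguously
    sampler = DistributedSampler(dataset, num_replicas=dp_size, rank=dp_rank,
                                 shuffle=shuffle, drop_last=drop_last)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      drop_last=drop_last, collate_fn=collate_fn)

Passing the size directly lets the same helper handle the mixed tensor/pipeline/data layouts that the new '3d' options enable.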

View File

@@ -48,29 +48,29 @@ def train(args):
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True)
# check if the hybrid parallel plugin is compatible with the model
# try:
# from colossalai.shardformer.policies.auto_policy import get_autopolicy
# policy = get_autopolicy(model)
# if policy is not None:
# if args.plugin in ['zero2', 'zero2_cpu']:
# # if compatible, set the plugin to hybrid, which use colo-attention
# args.plugin = 'hybrid'
# args.zero_stage = 2
# if args.plugin == 'zero2_cpu':
# args.zero_cpu_offload = True
# else:
# args.zero_cpu_offload = False
# logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
# except NotImplementedError:
# logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
# if args.use_flash_attn:
# del model
# model = AutoModelForCausalLM.from_pretrained(
# args.pretrain,
# torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
# attn_implementation="flash_attention_2",
# trust_remote_code=True
# )
try:
from colossalai.shardformer.policies.auto_policy import get_autopolicy
policy = get_autopolicy(model)
if policy is not None:
if args.plugin in ['zero2', 'zero2_cpu']:
# if compatible, set the plugin to '3d' (hybrid parallel), which uses colo-attention
args.plugin = '3d'
args.zero_stage = 2
if args.plugin == 'zero2_cpu':
args.zero_cpu_offload = True
else:
args.zero_cpu_offload = False
logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
except NotImplementedError:
logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
if args.use_flash_attn:
del model
model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
attn_implementation="flash_attention_2",
trust_remote_code=True
)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -112,7 +112,7 @@ def train(args):
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "hybrid":
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=args.pp,
@@ -224,7 +224,6 @@ def train(args):
lr_scheduler=lr_scheduler,
dataloader=train_dataloader,
)
# model = model.to(get_current_device())
torch.set_default_dtype(torch.float)
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
@@ -309,7 +308,7 @@ if __name__ == "__main__":
"--plugin",
type=str,
default="gemini",
choices=["gemini", "gemini_auto", "hybrid", "ddp", "zero2_cpu", "zero2"],
choices=["gemini", "gemini_auto", "3d", "ddp", "zero2_cpu", "zero2"],
help="Choose which plugin to use",
)
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
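
The previously commented-out compatibility probe in this script is now active: before building a plugin, the model is checked against Shardformer's policy registry, and a zero2 or zero2_cpu run is upgraded to the '3d' (HybridParallelPlugin) path so that colo-attention and flash attention become usable; models without a policy fall back to the plugin that was requested. The sketch below restates that decision as a hypothetical helper (maybe_upgrade_plugin). It reads the original plugin name before overwriting it, whereas in the hunk above args.plugin is set to '3d' before the zero2_cpu comparison, so the CPU-offload branch there can never trigger.

from colossalai.shardformer.policies.auto_policy import get_autopolicy


def maybe_upgrade_plugin(model, args, logger):
    # Hypothetical wrapper around the logic shown above.
    try:
        policy = get_autopolicy(model)  # raises NotImplementedError if unsupported
        if policy is not None and args.plugin in ("zero2", "zero2_cpu"):
            # Re-express the ZeRO request through the hybrid '3d' plugin so that
            # colo-attention / flash-attention kernels are available.
            args.zero_cpu_offload = args.plugin == "zero2_cpu"
            args.plugin = "3d"
            args.zero_stage = 2
            logger.info(
                f"Model is compatible with hybrid parallel plugin, set plugin to "
                f"{args.plugin} with zero_stage {args.zero_stage} and "
                f"zero_cpu_offload {args.zero_cpu_offload}"
            )
    except NotImplementedError:
        logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
    return args

When --use_flash_attn is set, the script then reloads the model with attn_implementation="flash_attention_2", the Hugging Face transformers switch for the FlashAttention-2 kernels.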

View File

@@ -15,24 +15,24 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
# export CUDA_VISIBLE_DEVICES=4,5,6
set_n_least_used_CUDA_VISIBLE_DEVICES 4
set_n_least_used_CUDA_VISIBLE_DEVICES 2
PROJECT_NAME="sft"
PARENT_SAVE_DIR="/home/yeanbang/data/experiment/output/model" # Path to a folder to save checkpoints
PARENT_TENSORBOARD_DIR="/home/yeanbang/data/experiment/logs/tensorboard" # Path to a folder to save logs
PARENT_CONFIG_FILE="/home/yeanbang/data/experiment/logs/config" # Path to a folder to save training config logs
PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Llama-2-7b-chat-hf" # huggingface or local tokenizer path
PRETRAINED_MODEL_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local model path
PRETRAINED_TOKENIZER_PATH="/mnt/jfs-hdd/share/models/Yi-1.5-6B" # huggingface or local tokenizer path
declare -a dataset=(
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00000
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00001
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00002
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00003
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00004
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00005
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00006
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00007
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00008
/home/yeanbang/data/experiment/dataset/alpaca/test/Llama-2-7b-chat-hf/arrow/part-00009
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00000
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00001
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00002
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00003
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00004
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00005
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00006
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00007
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00008
/home/yeanbang/data/experiment/dataset/alpaca/test/Yi-1.5-6B/arrow/part-00009
)
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
@@ -43,7 +43,7 @@ CONFIG_FILE="${PARENT_CONFIG_FILE}-${FULL_PROJECT_NAME}.json"
echo $(which colossalai)
echo $(which python)
# the real batch size for gradient descent is number_of_node_in_hostfile * nproc_per_node * train_batch_size
colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile train_sft.py \
colossalai run --nproc_per_node 2 --master_port 31312 --hostfile ./hostfile train_sft.py \
--pretrain $PRETRAINED_MODEL_PATH \
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
--save_interval 4000 \
@@ -51,13 +51,13 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
--save_path $SAVE_DIR \
--config_file $CONFIG_FILE \
--lora_rank 0 \
--plugin zero2 \
--tp 1 \
--plugin 3d \
--tp 2 \
--pp 1 \
--zero_stage 2 \
--batch_size 4 \
--zero_stage 0 \
--batch_size 2 \
--max_epochs 3 \
--accumulation_steps 4 \
--accumulation_steps 1 \
--lr 5e-5 \
--max_len 400 \
--grad_checkpoint \