moupdate ci tests, st ci test cases passed, tp failed in generation for ppo, sp is buggy

This commit is contained in:
YeAnbang
2024-05-28 07:58:08 +00:00
parent 7e65b71815
commit 0b4a33548c
7 changed files with 355 additions and 91 deletions

View File

@@ -56,7 +56,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -64,7 +64,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -89,8 +89,8 @@ def train(args):
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
@@ -180,7 +180,9 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -300,6 +302,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])

View File

@@ -18,7 +18,6 @@ from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dr
from coati.trainer import PPOTrainer
from coati.utils import load_checkpoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from colossalai.shardformer.policies.auto_policy import get_autopolicy
import colossalai
from colossalai.booster import Booster
@@ -27,6 +26,7 @@ from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
@@ -52,7 +52,6 @@ def train(args):
# )
init_ctx = nullcontext()
booster_policy = None
with init_ctx:
if args.use_flash_attn:
actor = AutoModelForCausalLM.from_pretrained(
@@ -150,34 +149,6 @@ def train(args):
adamw_mode=True,
)
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = setup_distributed_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = setup_distributed_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
)
else:
train_pretrain_dataloader = None
if args.warmup_steps is None:
args.warmup_steps = int(0.025 * args.num_episodes)
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
@@ -212,7 +183,7 @@ def train(args):
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=True,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@@ -220,7 +191,7 @@ def train(args):
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
enable_flash_attention=args.use_flash_attn
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
@@ -240,9 +211,15 @@ def train(args):
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=0,
pp_size=args.pp,
sp_size=args.sp,
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
custom_plugin = HybridParallelPlugin(
@@ -252,8 +229,8 @@ def train(args):
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
@@ -265,6 +242,38 @@ def train(args):
if args.plugin != "3d":
custom_plugin = plugin
# configure dataset
coordinator.print_on_master(f"Load dataset: {args.prompt_dataset}")
mode_map = {"train": "train", "valid": "validation", "test": "test"}
train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
train_prompt_dataloader = setup_distributed_dataloader(
dataset=train_prompt_dataset,
batch_size=args.experience_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
if len(args.ptx_dataset) > 0:
train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
train_pretrain_dataloader = setup_distributed_dataloader(
dataset=train_ptx_dataset,
batch_size=args.ptx_batch_size,
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
else:
train_pretrain_dataloader = None
actor_booster = Booster(plugin=plugin)
ref_booster = Booster(plugin=plugin)
rm_booster = Booster(plugin=custom_plugin)
@@ -459,6 +468,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])

View File

@@ -15,8 +15,7 @@ from coati.dataset import (
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
from coati.trainer import RewardModelTrainer
from coati.utils import load_checkpoint
from transformers import AutoTokenizer, AutoConfig
from colossalai.shardformer.policies.auto_policy import get_autopolicy
from transformers import AutoConfig, AutoTokenizer
import colossalai
from colossalai.booster import Booster
@@ -24,6 +23,7 @@ from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLev
from colossalai.cluster import DistCoordinator
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
def train(args):
@@ -47,7 +47,6 @@ def train(args):
# )
init_ctx = nullcontext()
booster_policy = None
with init_ctx:
if args.use_flash_attn:
model = RewardModel(
@@ -57,7 +56,7 @@ def train(args):
)
coordinator.print_on_master(msg="Flash-attention enabled successfully")
else:
model_config = AutoConfig.from_pretrained(args.pretrain)
AutoConfig.from_pretrained(args.pretrain)
model = RewardModel(
args.pretrain,
)
@@ -114,12 +113,12 @@ def train(args):
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
custom_policy=get_autopolicy(model.model)
custom_policy=get_autopolicy(model.model),
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -177,7 +176,9 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
@@ -297,6 +298,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])

View File

@@ -125,11 +125,12 @@ def train(args):
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_sequence_parallelism=True if args.sp > 1 else False,
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.batch_size,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
@@ -194,7 +195,9 @@ def train(args):
shuffle=True,
drop_last=True,
collate_fn=data_collator,
tp_size=args.tp,
tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
)
# print(len(train_dataloader))
# for batch in train_dataloader:
@@ -321,6 +324,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
parser.add_argument("--sp_mode", type=str, default="split_gather", choices=["split_gather", "ring", "all_to_all"])