Replace the customized dataloader setup with the built-in one

Author: YeAnbang
Date: 2024-06-07 09:43:42 +00:00
parent 790e1362a6
commit 0d7ff10ea5
12 changed files with 79 additions and 218 deletions


@@ -12,7 +12,6 @@ from coati.dataset import (
     StatefulDistributedSampler,
     load_tokenized_dataset,
     setup_conversation_template,
-    setup_distributed_dataloader,
 )
 from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
 from coati.trainer import PPOTrainer
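
With setup_distributed_dataloader removed, the remaining imports from the hunk above read as follows (a partial reconstruction limited to the lines visible in the hunk; entries of the same import group outside the hunk, such as the data collators, are not shown):

from coati.dataset import (
    StatefulDistributedSampler,
    load_tokenized_dataset,
    setup_conversation_template,
)
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
from coati.trainer import PPOTrainer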
@@ -209,6 +208,9 @@ def train(args):
             max_norm=args.grad_clip,
         )
     elif args.plugin == "3d":
+        if args.use_flash_attn and (args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism):
+            logger.warning("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
+            args.use_flash_attn = False
         plugin = HybridParallelPlugin(
             tp_size=args.tp,
             pp_size=args.pp,
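
The new guard turns flash attention off whenever any 3D-parallel dimension is active. A minimal, self-contained sketch of the same condition, using a hypothetical argparse.Namespace in place of the script's parsed CLI arguments:

from argparse import Namespace

# Hypothetical stand-in for the parsed CLI arguments; field names follow the script above.
args = Namespace(use_flash_attn=True, tp=2, pp=1, sp=1, enable_sequence_parallelism=False)

uses_3d_parallelism = args.tp > 1 or args.pp > 1 or args.sp > 1 or args.enable_sequence_parallelism
if args.use_flash_attn and uses_3d_parallelism:
    # Same effect as the added lines: warn once, then fall back to the non-flash attention path.
    print("Flash attention cannot be used with 3D parallelism for PPO training. Disabling it.")
    args.use_flash_attn = False

assert args.use_flash_attn is False  # tp=2 trips the guard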
@@ -247,29 +249,26 @@ def train(args):
     mode_map = {"train": "train", "valid": "validation", "test": "test"}
     train_prompt_dataset = load_tokenized_dataset(dataset_paths=args.prompt_dataset, mode="train", mode_map=mode_map)
     data_collator = DataCollatorForPromptDataset(tokenizer=tokenizer, max_length=args.max_length - args.max_seq_len)
-    train_prompt_dataloader = setup_distributed_dataloader(
+    train_prompt_dataloader = plugin.prepare_dataloader(
         dataset=train_prompt_dataset,
         batch_size=args.experience_batch_size,
         shuffle=True,
         drop_last=True,
         collate_fn=data_collator,
-        tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
-        sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
-        pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
         distributed_sampler_cls=StatefulDistributedSampler,
     )
     if len(args.ptx_dataset) > 0:
         train_ptx_dataset = load_tokenized_dataset(dataset_paths=args.ptx_dataset, mode="train", mode_map=mode_map)
         data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_length)
-        train_pretrain_dataloader = setup_distributed_dataloader(
+        train_pretrain_dataloader = plugin.prepare_dataloader(
             dataset=train_ptx_dataset,
             batch_size=args.ptx_batch_size,
             shuffle=True,
             drop_last=True,
             collate_fn=data_collator,
-            tp_size=plugin.tp_size if hasattr(plugin, "tp_size") else 1,
-            sp_size=plugin.sp_size if hasattr(plugin, "sp_size") else 1,
-            pp_size=plugin.pp_size if hasattr(plugin, "pp_size") else 1,
             distributed_sampler_cls=StatefulDistributedSampler,
         )
     else:
         train_pretrain_dataloader = None
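
Both call sites now go through the plugin's built-in prepare_dataloader, passing StatefulDistributedSampler via distributed_sampler_cls, presumably to keep the resumable sampling the old helper provided. The sketch below shows the same pattern in isolation; it is a sketch only, assuming a recent ColossalAI (where colossalai.launch_from_torch() no longer takes a config), ColossalChat installed for the coati import, a single-process torchrun launch, placeholder parallelism settings, and a toy TensorDataset in place of the tokenized prompt dataset and its collator:

import colossalai
import torch
from coati.dataset import StatefulDistributedSampler
from colossalai.booster.plugin import HybridParallelPlugin
from torch.utils.data import TensorDataset

# Initialise torch.distributed; run with: torchrun --nproc_per_node=1 <this_script>.py
colossalai.launch_from_torch()

# Placeholder parallelism settings; the training script takes these from its CLI arguments.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, zero_stage=0)

# Toy dataset standing in for load_tokenized_dataset(...) plus the prompt collator.
dataset = TensorDataset(torch.arange(64).unsqueeze(1))

# The plugin shards the dataset across the data-parallel group; the stateful
# sampler keeps the ability to resume from a saved start index.
train_dataloader = plugin.prepare_dataloader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    drop_last=True,
    distributed_sampler_cls=StatefulDistributedSampler,
)

for batch in train_dataloader:
    pass  # a real run would feed these batches to the PPO trainer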