This commit is contained in:
YeAnbang
2024-07-11 03:35:03 +00:00
parent 8a9721bafe
commit e7a8634636
10 changed files with 171 additions and 41 deletions

View File

@@ -176,6 +176,19 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
@@ -260,7 +273,7 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@@ -309,6 +322,7 @@ if __name__ == "__main__":
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@@ -164,6 +164,19 @@ def train(args):
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
@@ -242,7 +255,7 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
eval_preference_dataloader=None,
eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@@ -288,6 +301,7 @@ if __name__ == "__main__":
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@@ -173,6 +173,20 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)
@@ -297,6 +311,7 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@@ -173,6 +173,21 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
@@ -255,7 +270,7 @@ def train(args):
trainer.fit(
train_dataloader=train_dataloader,
eval_dataloader=None,
eval_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@@ -300,6 +315,7 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)