fix eval
Mirror of https://github.com/hpcaitech/ColossalAI.git

This commit threads the new --eval_dataset option through four training scripts: each script now builds an evaluation dataloader when evaluation data is supplied and passes it to trainer.fit() instead of a hard-coded None.
@@ -176,6 +176,19 @@ def train(args):
         collate_fn=data_collator,
         distributed_sampler_cls=StatefulDistributedSampler,
     )
+    eval_dataloader = None
+    if args.eval_dataset:
+        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
+        eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
+
+        eval_dataloader = plugin.prepare_dataloader(
+            dataset=eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=eval_data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )

     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
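Each of the four files below gains the same guard: build an evaluation dataloader only when --eval_dataset is given, otherwise leave eval_dataloader as None so evaluation is skipped. A minimal standalone sketch of that pattern, using plain torch.utils.data stand-ins for the repository's plugin.prepare_dataloader and load_tokenized_dataset (the dataset contents, batch size, and paths below are invented):

    # Sketch only: plain-PyTorch stand-ins for ColossalAI's plugin.prepare_dataloader
    # and load_tokenized_dataset; dataset contents and sizes are invented.
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    def build_eval_dataloader(eval_dataset_paths, batch_size=4):
        # Same guard as the diff: no paths -> no eval dataloader.
        if not eval_dataset_paths:
            return None
        # Placeholder for load_tokenized_dataset(dataset_paths=..., mode="dev").
        eval_dataset = TensorDataset(torch.arange(32).unsqueeze(1))
        return DataLoader(eval_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    print(build_eval_dataloader([]))             # None -> trainer will skip evaluation
    print(build_eval_dataloader(["dev.jsonl"]))  # DataLoader over placeholder data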
@@ -260,7 +273,7 @@ def train(args):
 
     trainer.fit(
         train_preference_dataloader=train_dataloader,
-        eval_preference_dataloader=None,
+        eval_preference_dataloader=eval_dataloader,
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
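This one-line change is the actual fix: the eval dataloader built above is now passed through instead of a hard-coded None, which had silently disabled evaluation. A trainer that accepts an optional eval loader would typically branch on it along these lines (a hedged sketch, not the repository's trainer; all names here are assumptions):

    from typing import Iterable, Optional

    class SketchTrainer:
        """Hypothetical trainer showing why eval_preference_dataloader=None meant no eval."""

        def fit(self, train_dataloader: Iterable, eval_dataloader: Optional[Iterable] = None):
            for batch in train_dataloader:
                pass  # training step would run here
            if eval_dataloader is not None:  # with None, this branch never runs
                for batch in eval_dataloader:
                    pass  # evaluation step would run here

    SketchTrainer().fit(train_dataloader=[1, 2, 3], eval_dataloader=None)  # trains only
    SketchTrainer().fit(train_dataloader=[1, 2, 3], eval_dataloader=[4])   # trains and evaluates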
@@ -309,6 +322,7 @@ if __name__ == "__main__":
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument("--eval_dataset", nargs="+", default=[])
     parser.add_argument(
         "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
     )
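The new --eval_dataset flag mirrors --dataset: nargs="+" collects one or more paths into a list, and the default empty list is falsy, which is exactly what the if args.eval_dataset: guard above relies on. A quick self-contained illustration (the paths are made up):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--eval_dataset", nargs="+", default=[])

    # No flag given: the default [] is falsy, so no eval dataloader is built.
    print(parser.parse_args([]).eval_dataset)  # []

    # One or more paths collect into a list (hypothetical paths).
    args = parser.parse_args(["--eval_dataset", "data/dev_a", "data/dev_b"])
    print(args.eval_dataset)  # ['data/dev_a', 'data/dev_b']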
@@ -164,6 +164,19 @@ def train(args):
         distributed_sampler_cls=StatefulDistributedSampler,
     )
+
+    eval_dataloader = None
+    if args.eval_dataset:
+        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
+        eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
+        eval_dataloader = plugin.prepare_dataloader(
+            dataset=eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=eval_data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
 
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     if args.warmup_steps is None:
         args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
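The trailing context also shows the default warmup schedule: when --warmup_steps is not given, it is set to 2.5% of the total number of optimizer updates over all epochs (note the batch count is divided by accumulation_steps first). A worked instance of that arithmetic with invented sizes:

    # Assumed example sizes, not values from the repository.
    max_epochs = 3
    len_train_dataloader = 1000  # batches per epoch
    accumulation_steps = 8

    num_update_steps_per_epoch = len_train_dataloader // accumulation_steps  # 125
    warmup_steps = int(max_epochs * 0.025 * num_update_steps_per_epoch)      # int(9.375) -> 9
    print(num_update_steps_per_epoch, warmup_steps)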
@@ -242,7 +255,7 @@ def train(args):
 
     trainer.fit(
         train_preference_dataloader=train_dataloader,
-        eval_preference_dataloader=None,
+        eval_preference_dataloader=eval_dataloader,
         log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
@@ -288,6 +301,7 @@ if __name__ == "__main__":
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument("--eval_dataset", nargs="+", default=[])
     parser.add_argument(
         "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
     )
@@ -173,6 +173,20 @@ def train(args):
         collate_fn=data_collator,
         distributed_sampler_cls=StatefulDistributedSampler,
     )
+
+    eval_dataloader = None
+    if args.eval_dataset:
+        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
+        eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
+        eval_dataloader = plugin.prepare_dataloader(
+            dataset=eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=eval_data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
+
     num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
     math.ceil(args.max_epochs * num_update_steps_per_epoch)
 
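One oddity in this hunk's context: math.ceil(args.max_epochs * num_update_steps_per_epoch) is a bare expression whose result is discarded, so the line is a no-op as written. Presumably a total-step count was intended; a sketch of the discarded versus assigned forms (total_steps is my name, not the repository's):

    import math

    max_epochs = 3
    num_update_steps_per_epoch = 125

    # Bare expression: computed and thrown away, exactly like the context line.
    math.ceil(max_epochs * num_update_steps_per_epoch)

    # Assigned form a scheduler could actually consume (name is hypothetical).
    total_steps = math.ceil(max_epochs * num_update_steps_per_epoch)
    print(total_steps)  # 375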
@@ -297,6 +311,7 @@ if __name__ == "__main__":
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument("--eval_dataset", nargs="+", default=[])
     parser.add_argument(
         "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
     )
@@ -173,6 +173,21 @@ def train(args):
         collate_fn=data_collator,
         distributed_sampler_cls=StatefulDistributedSampler,
     )
+
+    eval_dataloader = None
+    if args.eval_dataset:
+        eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
+        eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
+
+        eval_dataloader = plugin.prepare_dataloader(
+            dataset=eval_dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            drop_last=True,
+            collate_fn=eval_data_collator,
+            distributed_sampler_cls=StatefulDistributedSampler,
+        )
+
     coordinator.print_on_master(
         f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
     )
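The context after this hunk logs peak CUDA allocator usage once the dataloaders exist, a quick sanity check that the new eval data did not inflate GPU memory; the f-string converts bytes to MiB. A standalone version, guarded so it also runs on CPU-only machines:

    import torch

    # Peak CUDA memory in MiB, as in the context line (0 on CPU-only machines).
    peak_bytes = torch.cuda.max_memory_allocated() if torch.cuda.is_available() else 0
    print(f"Max CUDA memory after data loader: {peak_bytes / 1024 ** 2:.2f} MB")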
@@ -255,7 +270,7 @@ def train(args):
 
     trainer.fit(
         train_dataloader=train_dataloader,
-        eval_dataloader=None,
+        eval_dataloader=eval_dataloader,
        log_dir=args.log_dir,
         use_wandb=args.use_wandb,
     )
@@ -300,6 +315,7 @@ if __name__ == "__main__":
     parser.add_argument("--pretrain", type=str, default=None)
     parser.add_argument("--tokenizer_dir", type=str, default=None)
     parser.add_argument("--dataset", nargs="+", default=[])
+    parser.add_argument("--eval_dataset", nargs="+", default=[])
     parser.add_argument(
         "--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
     )