Merge pull request #5901 from hpcaitech/colossalchat

[Chat] fix eval: add in training evaluation, fix orpo sft loss bug
YeAnbang 2024-07-16 11:07:32 +08:00 committed by GitHub
commit d8bf7e09a2
11 changed files with 214 additions and 50 deletions
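For context on the "fix orpo sft loss bug" part of this commit: the ORPO trainer previously rebuilt the SFT term by hand with a separate CrossEntropyLoss over the chosen logits; the diff below drops that path and instead passes labels to the model with the rejected half of the batch set to -100, so the model's built-in causal-LM loss already gives the chosen-only NLL (chosen_nll = actor_out["loss"]). A minimal sketch of that idea, assuming an HF-style model that returns a loss when labels are given; the helper name orpo_labels and the variable lam are illustrative, not part of the repository:

```python
# Minimal sketch (not repository code): build ORPO labels so that only the chosen
# half of the concatenated batch contributes to the model's built-in cross-entropy.
import torch

IGNORE_INDEX = -100  # positions labeled -100 are ignored by HF-style causal-LM losses

def orpo_labels(chosen_input_ids: torch.Tensor, reject_input_ids: torch.Tensor) -> torch.Tensor:
    """Labels for a batch laid out as [chosen; rejected]: real ids for chosen, -100 for rejected."""
    masked_reject = torch.ones_like(reject_input_ids) * IGNORE_INDEX
    return torch.cat([chosen_input_ids, masked_reject])

# Rough shape of the objective as wired in ORPOTrainer after this change:
#   actor_out  = model(input_ids=..., attention_mask=..., labels=orpo_labels(chosen, reject))
#   chosen_nll = actor_out["loss"]                     # SFT/NLL term over chosen responses only
#   or_loss, log_odds_ratio = odds_ratio_loss_fn(...)  # odds-ratio preference term
#   loss       = chosen_nll - or_loss * lam            # lam corresponds to self.lam
```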

View File

@ -529,7 +529,7 @@ Coati is developed by ColossalAI Team:
- [Fazzie](https://fazzie-key.cool/about/index.html) Contributing to the algorithm and development for SFT.
- [ofey404](https://github.com/ofey404) Contributing to both front-end and back-end development.
- [Wenhao Chen](https://github.com/CWHer) Contributing to subsequent code enhancements and performance improvements.
- - [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO.
+ - [Anbang Ye](https://github.com/YeAnbang) Contributing to the refactored PPO version with updated acceleration framework. Add support for DPO, SimPO, ORPO.
The PhD student from [(HPC-AI) Lab](https://ai.comp.nus.edu.sg/) also contributed a lot to this project.
- [Zangwei Zheng](https://github.com/zhengzangw)
@ -579,6 +579,36 @@ We also appreciate the valuable suggestions provided by [Jian Hu](https://github
journal = {GitHub repository},
howpublished = {\url{https://github.com/XueFuzhao/InstructionWild}},
}
@misc{meng2024simposimplepreferenceoptimization,
title={SimPO: Simple Preference Optimization with a Reference-Free Reward},
author={Yu Meng and Mengzhou Xia and Danqi Chen},
year={2024},
eprint={2405.14734},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2405.14734},
}
@misc{rafailov2023directpreferenceoptimizationlanguage,
title={Direct Preference Optimization: Your Language Model is Secretly a Reward Model},
author={Rafael Rafailov and Archit Sharma and Eric Mitchell and Stefano Ermon and Christopher D. Manning and Chelsea Finn},
year={2023},
eprint={2305.18290},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2305.18290},
}
@misc{hong2024orpomonolithicpreferenceoptimization,
title={ORPO: Monolithic Preference Optimization without Reference Model},
author={Jiwoo Hong and Noah Lee and James Thorne},
year={2024},
eprint={2403.07691},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2403.07691},
}
```
## Licenses

View File

@ -28,6 +28,8 @@ def load_tokenized_dataset(
Each instance of dataset is a dictionary with
`{'input_ids': List[int], 'labels': List[int], sequence: str}` format.
"""
if not dataset_paths:
return None
mode_map = kwargs.get("mode_map", {"train": "train", "dev": "validation", "test": "test"})
assert mode in tuple(mode_map), f"Unsupported mode {mode}, it must be in {tuple(mode_map)}"

View File

@ -2,6 +2,7 @@
Dpo trainer
"""
import os
from typing import Any, Optional
import torch
@ -324,7 +325,7 @@ class DPOTrainer(SLTrainer):
chosen_loss_mask[:, 1:],
reject_loss_mask[:, 1:],
)
- reward_accuracies = (chosen_rewards > rejected_rewards).float()
+ reward_accuracies = (chosen_rewards > rejected_rewards).float().mean()
loss = losses.mean()
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards)
@ -343,4 +344,7 @@ class DPOTrainer(SLTrainer):
for tag in ["loss", "chosen_rewards", "rejected_rewards", "accuracy", "margin"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -2,6 +2,7 @@
Orpo trainer
"""
import os
from typing import Any, Optional
import torch
@ -9,7 +10,6 @@ from coati.models.loss import OddsRatioLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
@ -62,7 +62,6 @@ class ORPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss()
self.sft_loss_fn = CrossEntropyLoss()
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@ -135,6 +134,9 @@ class ORPOTrainer(SLTrainer):
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
@ -143,13 +145,8 @@ class ORPOTrainer(SLTrainer):
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
label_chosen = chosen_input_ids[:, 1:].contiguous()
label_chosen_masked = (
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
)
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
- chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
+ chosen_nll = actor_out["loss"]
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
)
@ -269,11 +266,13 @@ class ORPOTrainer(SLTrainer):
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
labels=torch.cat([chosen_input_ids, reject_input_ids]),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
labels=torch.cat(
[chosen_input_ids, torch.ones_like(reject_input_ids, dtype=reject_input_ids.dtype) * -100]
),
)
torch.autograd.set_detect_anomaly(True)
actor_all_logits = actor_out["logits"].to(torch.float32)
chosen_nll = torch.mean(actor_out["loss"][:batch_size]).to(dtype=torch.bfloat16)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
@ -283,14 +282,16 @@ logprob_actor_chosen = calc_masked_log_probs(
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
)
chosen_nll = actor_out["loss"]
- odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(logprob_actor_chosen, logprob_actor_reject)
+ odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
+ logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
+ )
loss = chosen_nll - odds_ratio_loss * self.lam
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
- chosen_rewards = torch.mean(logprob_actor_chosen).item()
+ chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
- rejected_rewards = torch.mean(logprob_actor_reject).item()
+ rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
- reward_accuracies = (log_odds_ratio > 0).float().mean().item()
+ reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
# sync
loss_mean = all_reduce_mean(tensor=loss)
@ -303,37 +304,11 @@ class ORPOTrainer(SLTrainer):
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
# logging
if self.writer and is_rank_0():
self.writer.add_scalar("eval/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/rejected_rewards",
self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/log",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/accuracy",
self.accumulative_meter.get("accuracy"),
self.num_train_step,
)
self.writer.add_scalar(
"train/log_odds_ratio",
self.accumulative_meter.get("log_odds_ratio"),
self.num_train_step,
)
self.step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
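The evaluation metrics in the ORPO hunk above also move from plain means over log-probabilities to mask-aware averages (sum of log-probs divided by the number of unmasked tokens, and accuracy counted only over non-zero log-odds entries). A small standalone illustration of why a masked mean differs from a naive mean; the tensors below are made up for the example and are not taken from the trainer:

```python
# Standalone illustration (made-up tensors): naive mean vs. loss-mask-aware mean.
import torch

logp = torch.tensor([[-0.5, -1.0, -2.0],
                     [-0.2, -0.4, -3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 0.0, 0.0]])  # 1 = response token kept by the loss mask, 0 = ignored

naive_mean = logp.mean()                                 # averages over ignored positions as well
masked_mean = torch.sum(logp * mask) / torch.sum(mask)   # averages only over kept tokens

print(f"naive: {naive_mean.item():.4f}, masked: {masked_mean.item():.4f}")
```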

View File

@ -237,6 +237,7 @@ class RewardModelTrainer(SLTrainer):
+ f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
)
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -167,6 +167,7 @@ class SFTTrainer(SLTrainer):
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -176,6 +176,21 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
@ -260,7 +275,7 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
+ eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@ -309,6 +324,7 @@ if __name__ == "__main__":
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)
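The eval-dataloader wiring shown above for the DPO script is repeated with the same shape in the ORPO, reward-model, and SFT training scripts below (the SFT script uses DataCollatorForSupervisedDataset with max_len instead of the preference collator). Condensed, the added pattern is: build an evaluation dataloader only when --eval_dataset paths are given, otherwise log a warning and skip evaluation, then hand the dataloader (or None) to trainer.fit. The sketch below just condenses those hunks and assumes the surrounding names (plugin, tokenizer, args, logger) from the scripts:

```python
# Condensed restatement of the eval-dataloader pattern added in this commit
# (assumes plugin, tokenizer, args, logger, load_tokenized_dataset,
#  DataCollatorForPreferenceDataset, StatefulDistributedSampler from the surrounding script).
eval_dataloader = None
if args.eval_dataset:
    eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
    eval_dataloader = plugin.prepare_dataloader(
        dataset=eval_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length),
        distributed_sampler_cls=StatefulDistributedSampler,
    )
else:
    logger.warning("No evaluation dataset is provided, skip evaluation")
# trainer.fit(..., eval_preference_dataloader=eval_dataloader, ...)
```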

View File

@ -164,6 +164,21 @@ def train(args):
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
if args.warmup_steps is None:
args.warmup_steps = int(args.max_epochs * 0.025 * (len(train_dataloader) // args.accumulation_steps))
@ -242,7 +257,7 @@ def train(args):
trainer.fit(
train_preference_dataloader=train_dataloader,
- eval_preference_dataloader=None,
+ eval_preference_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@ -288,6 +303,7 @@ if __name__ == "__main__":
parser.add_argument("--model_type", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@ -16,10 +16,13 @@ import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.shardformer.policies.auto_policy import get_autopolicy
logger = get_dist_logger()
def train(args):
# check lora compatibility
@ -173,6 +176,22 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=args.max_length)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
math.ceil(args.max_epochs * num_update_steps_per_epoch)
@ -297,6 +316,7 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@ -173,6 +173,23 @@ def train(args):
collate_fn=data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
eval_dataloader = None
if args.eval_dataset:
eval_dataset = load_tokenized_dataset(dataset_paths=args.eval_dataset, mode="dev")
eval_data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, max_length=args.max_len)
eval_dataloader = plugin.prepare_dataloader(
dataset=eval_dataset,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=eval_data_collator,
distributed_sampler_cls=StatefulDistributedSampler,
)
else:
logger.warning("No evaluation dataset is provided, skip evaluation")
coordinator.print_on_master(
f"Max CUDA memory after data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
)
@ -255,7 +272,7 @@ def train(args):
trainer.fit(
train_dataloader=train_dataloader,
- eval_dataloader=None,
+ eval_dataloader=eval_dataloader,
log_dir=args.log_dir,
use_wandb=args.use_wandb,
)
@ -300,6 +317,7 @@ if __name__ == "__main__":
parser.add_argument("--pretrain", type=str, default=None)
parser.add_argument("--tokenizer_dir", type=str, default=None)
parser.add_argument("--dataset", nargs="+", default=[])
parser.add_argument("--eval_dataset", nargs="+", default=[])
parser.add_argument(
"--checkpoint_path", type=str, default=None, help="Checkpoint path if need to resume training form a checkpoint"
)

View File

@ -173,6 +173,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
@ -248,6 +249,7 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
@ -423,6 +425,85 @@ for lora_rank in ${LORA_RANK[@]}; do
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
passed=$?
if [ $passed -eq 0 ]; then
rm -rf $MODEL_SAVE_PATH/*
rm -rf $MODELS_DIR/*
break
fi
done
if [ $passed -ne 0 ]; then
echo "[Test]: Failed $model-$plugin-$lora_rank"
exit 1
fi
done
done
done
echo "[Test]: testing ORPO ..."
SKIPPED_TESTS=(
llama-3d-20 # 3d plugin doesn't support lora
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora
)
GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do
for plugin in ${PLUGINS[@]}; do
if [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin-$lora_rank " ]]; then
echo "[Test]: Skipped $model-$plugin-$lora_rank"
continue
elif [[ " ${SKIPPED_TESTS[*]} " =~ " $model-$plugin " ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
pretrain=$(get_pretrain $model)
tokenizer_dir=$(get_tokenizer_dirs $model)
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
tp='1'
bs='2'
if [[ $plugin == "3d" ]]; then
tp='4'
bs='8'
fi
grad_accu='2'
# gemini_auto and gemini doesn't support gradient accumulation
if [[ $plugin == "gemini_auto" ]]; then
grad_accu='1'
fi
# gemini_auto doesn't support generation
# (need to calculate ref_model logits through forwarding in inference mode)
if [[ $plugin == "gemini_auto" ]]; then
echo "[Test]: Skipped $model-$plugin"
continue
fi
for i in $(seq $NUM_RETRY); do
echo "[Test]: $model-$plugin-$lora_rank, attempt $i"
declare -a dataset=()
for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_preference/arrow/part-$split")
done
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_orpo.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_dir $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
--lora_rank $lora_rank \