ColossalAI/applications/ColossalChat/coati/trainer/rm.py
YeAnbang df5e9c53cf
[ColossalChat] Update RLHF V2 (#5286)
* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporary

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>
2024-03-29 14:12:29 +08:00

243 lines
10 KiB
Python
Executable File

"""
Reward model trianer
"""
import os
from typing import Any, Callable, Optional
import torch
import tqdm
from coati.models import LogSigLoss
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device
from .base import SLTrainer
from .utils import is_rank_0, to_device
class RewardModelTrainer(SLTrainer):
"""
Trainer for PPO algorithm.
Args:
actor (Actor): the actor model in ppo algorithm
ref_model (Critic): the reference model in ppo algorithm
booster (Strategy): the strategy to use for training
actor_optim (Optimizer): the optimizer to use for actor model
actor_lr_scheduler (_LRScheduler): the lr scheduler to use for actor model
tokenizer (PreTrainedTokenizerBase): the tokenizer to use for encoding
max_epochs (int, defaults to 1): the max number of epochs to train
beta (float, defaults to 0.1): the beta parameter in dpo loss
accumulation_steps (int): the number of steps to accumulate gradients
start_epoch (int, defaults to 0): the start epoch, non-zero if resumed from a checkpoint
save_interval (int): the interval to save model checkpoints, default to 0, which means no checkpoint will be saved during trainning
save_dir (str): the directory to save checkpoints
coordinator (DistCoordinator): the coordinator to use for distributed logging
"""
def __init__(
self,
model: Any,
booster: Booster,
optimizer: Optimizer,
lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None,
max_epochs: int = 1,
beta: float = 0.1,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
save_dir: str = None,
coordinator: DistCoordinator = None,
) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
self.num_train_step = 0
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
def _before_fit(
self,
train_preference_dataloader: DataLoader = None,
eval_preference_dataloader: DataLoader = None,
log_dir: Optional[str] = None,
use_wandb: bool = False,
):
"""
Args:
prompt_dataloader (DataLoader): the dataloader to use for prompt data
pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
"""
self.train_dataloader = train_preference_dataloader
self.eval_dataloader = eval_preference_dataloader
self.writer = None
if use_wandb and is_rank_0():
assert log_dir is not None, "log_dir must be provided when use_wandb is True"
import wandb
self.wandb_run = wandb.init(project="Coati-rm", sync_tensorboard=True)
if log_dir is not None and is_rank_0():
import os
import time
from torch.utils.tensorboard import SummaryWriter
log_dir = os.path.join(log_dir, "rm")
log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
self.writer = SummaryWriter(log_dir=log_dir)
def _train(self, epoch):
self.model.train()
step_bar = tqdm.trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
reject_input_ids,
reject_attention_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
)
batch_size = chosen_input_ids.size()[0]
# Concatenate for better parrallelism
reward = self.model(
torch.cat([chosen_input_ids, reject_input_ids], dim=0),
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask], dim=0),
)
chosen_reward = reward[:batch_size]
reject_reward = reward[batch_size:]
loss = self.loss_fn(chosen_reward, reject_reward).mean()
self.booster.backward(loss=loss, optimizer=self.optimizer)
accuracy = (chosen_reward > reject_reward).float()
# Sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
accuracy_mean = all_reduce_mean(tensor=accuracy)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.accumulative_meter.add("accuracy", accuracy_mean.mean().to(torch.float16).item())
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step()
self.optimizer.zero_grad()
self.actor_scheduler.step()
step_bar.update()
self.num_train_step += 1
# Logging
if self.writer and is_rank_0():
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
self.writer.add_scalar(
"train/dist",
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
self.num_train_step,
)
self.writer.add_scalar(
"train/reward_chosen", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
)
self.writer.add_scalar(
"train/reward_reject", self.accumulative_meter.get("rejected_rewards"), self.num_train_step
)
self.writer.add_scalar("train/acc", self.accumulative_meter.get("accuracy"), self.num_train_step)
self.accumulative_meter.reset()
# Save checkpoint
if self.save_interval > 0 and (self.num_train_step + 1) % self.save_interval == 0:
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.actor_scheduler,
epoch=epoch,
step=i + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {(i + 1)/self.accumulation_steps} at folder {self.save_dir}"
)
step_bar.close()
def _eval(self, epoch):
if self.eval_dataloader is None:
self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation")
return
self.model.eval()
step_bar = tqdm.trange(
len(self.eval_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
)
with torch.no_grad():
for i, batch in enumerate(self.eval_dataloader):
batch = to_device(batch, self.device)
(
chosen_input_ids,
chosen_attention_mask,
reject_input_ids,
reject_attention_mask,
) = (
batch["chosen_input_ids"],
batch["chosen_attention_mask"],
batch["reject_input_ids"],
batch["reject_attention_mask"],
)
chosen_reward = self.model(chosen_input_ids, attention_mask=chosen_attention_mask)
reject_reward = self.model(reject_input_ids, attention_mask=reject_attention_mask)
loss = self.loss_fn(chosen_reward, reject_reward).mean()
# Sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_reward)
rejected_rewards_mean = all_reduce_mean(tensor=reject_reward)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
step_bar.update()
msg = "Evaluation Result:\n"
for tag in ["loss", "chosen_rewards", "rejected_rewards"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
msg = (
msg
+ f"distance: {self.accumulative_meter.get('chosen_rewards')-self.accumulative_meter.get('rejected_rewards')}\n"
)
self.coordinator.print_on_master(msg)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()