mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
fix eval
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
Orpo trainer
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
@@ -269,11 +270,10 @@ class ORPOTrainer(SLTrainer):
|
||||
batch_size = chosen_input_ids.size()[0]
|
||||
actor_out = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
labels=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
actor_all_logits = actor_out["logits"].to(torch.float32)
|
||||
chosen_nll = torch.mean(actor_out["loss"][:batch_size]).to(dtype=torch.bfloat16)
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
@@ -283,14 +283,22 @@ class ORPOTrainer(SLTrainer):
|
||||
logprob_actor_reject = calc_masked_log_probs(
|
||||
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
|
||||
)
|
||||
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(logprob_actor_chosen, logprob_actor_reject)
|
||||
|
||||
chosen_logits = actor_chosen_logits[:, :-1, :].contiguous().view(-1, actor_chosen_logits.size(-1))
|
||||
label_chosen = chosen_input_ids[:, 1:].contiguous()
|
||||
label_chosen_masked = (
|
||||
label_chosen.masked_fill(chosen_loss_mask[:, 1:] == 0, -100).view(-1).contiguous().detach()
|
||||
)
|
||||
# label_chosen[chosen_loss_mask[:, 1:] == 0] = -100
|
||||
chosen_nll = self.sft_loss_fn(chosen_logits, label_chosen_masked).to(dtype=torch.bfloat16)
|
||||
odds_ratio_loss, log_odds_ratio = self.odds_ratio_loss_fn(
|
||||
logprob_actor_chosen, logprob_actor_reject, chosen_loss_mask[:, 1:], reject_loss_mask[:, 1:]
|
||||
)
|
||||
loss = chosen_nll - odds_ratio_loss * self.lam
|
||||
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
|
||||
|
||||
chosen_rewards = torch.mean(logprob_actor_chosen).item()
|
||||
rejected_rewards = torch.mean(logprob_actor_reject).item()
|
||||
reward_accuracies = (log_odds_ratio > 0).float().mean().item()
|
||||
chosen_rewards = torch.sum(logprob_actor_chosen) / torch.sum(chosen_loss_mask[:, 1:])
|
||||
rejected_rewards = torch.sum(logprob_actor_reject) / torch.sum(reject_loss_mask[:, 1:])
|
||||
reward_accuracies = torch.sum((log_odds_ratio > 0).float()) / torch.sum(log_odds_ratio != 0)
|
||||
|
||||
# sync
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
@@ -303,37 +311,11 @@ class ORPOTrainer(SLTrainer):
|
||||
self.accumulative_meter.add("log_odds_ratio", log_odds_ratio.to(torch.float16).mean().item())
|
||||
self.accumulative_meter.add("accuracy", reward_accuracies_mean.to(torch.float16).item())
|
||||
|
||||
# logging
|
||||
if self.writer and is_rank_0():
|
||||
self.writer.add_scalar("eval/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
|
||||
self.writer.add_scalar(
|
||||
"train/chosen_rewards", self.accumulative_meter.get("chosen_rewards"), self.num_train_step
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/rejected_rewards",
|
||||
self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log",
|
||||
self.accumulative_meter.get("chosen_rewards") - self.accumulative_meter.get("rejected_rewards"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/accuracy",
|
||||
self.accumulative_meter.get("accuracy"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.writer.add_scalar(
|
||||
"train/log_odds_ratio",
|
||||
self.accumulative_meter.get("log_odds_ratio"),
|
||||
self.num_train_step,
|
||||
)
|
||||
self.step_bar.update()
|
||||
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss", "chosen_rewards", "rejected_rewards", "log_odds_ratio", "accuracy"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
Reference in New Issue
Block a user