This commit is contained in:
YeAnbang
2024-07-18 07:54:11 +00:00
parent b3594d4d68
commit 09d5ffca1a
27 changed files with 1739 additions and 63 deletions

View File

@@ -5,6 +5,7 @@ loss functions
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from .utils import masked_mean
@@ -201,7 +202,79 @@ class OddsRatioLoss(nn.Module):
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
# print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
log_odds_ratio = chosen_odds_masked - reject_odds_masked
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
class KTOLoss(nn.Module):
def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
"""
Args:
beta: The temperature parameter in the KTO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable
"""
super().__init__()
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
def forward(
self,
chosen_logps: torch.Tensor,
rejected_logps: torch.Tensor,
kl_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
ref_kl_logps: torch.Tensor,
):
"""
Reference:
https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
Compute the KTO loss for a batch of policy and reference model log probabilities.
Args:
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
kl_logps: KL divergence of the policy model. Shape: (batch_size,)
ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
beta: The temperature parameter in the DPO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable responses.
Refer to the KTO paper for details about hyperparameters https://arxiv.org/pdf/2402.01306
"""
kl = (kl_logps - ref_kl_logps).mean().detach()
# all gather
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
# kl = 0
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# important to cast to policy_dtype; otherwise error will occur during all_gather
chosen_losses = torch.Tensor([]).to(
kl_logps.device
) # torch.Tensor(0.).to(chosen_logps.dtype).to(chosen_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
rejected_logratios = rejected_logps - ref_rejected_logps
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# important to cast to policy_dtype; otherwise error will occur during all_gather
rejected_losses = torch.Tensor([]).to(
kl_logps.device
) # torch.Tensor(0.).to(rejected_logps.dtype).to(rejected_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
return losses, chosen_rewards, rejected_rewards, kl