mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-04 10:34:41 +00:00)
add kto
@@ -5,6 +5,7 @@ loss functions
 from typing import Optional, Tuple
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 
 from .utils import masked_mean
@@ -201,7 +202,79 @@ class OddsRatioLoss(nn.Module):
         chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
         reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
         reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
         # print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
         log_odds_ratio = chosen_odds_masked - reject_odds_masked
         ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
         return ratio.to(dtype=torch.bfloat16), log_odds_ratio
+
+
+class KTOLoss(nn.Module):
+    def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
+        """
+        Args:
+            beta: The temperature parameter in the KTO paper.
+            desirable_weight: The weight for the desirable responses.
+            undesirable_weight: The weight for the undesirable responses.
+        """
+        super().__init__()
+        self.beta = beta
+        self.desirable_weight = desirable_weight
+        self.undesirable_weight = undesirable_weight
+
+    def forward(
+        self,
+        chosen_logps: torch.Tensor,
+        rejected_logps: torch.Tensor,
+        kl_logps: torch.Tensor,
+        ref_chosen_logps: torch.Tensor,
+        ref_rejected_logps: torch.Tensor,
+        ref_kl_logps: torch.Tensor,
+    ):
+        """
+        Reference:
+            https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
+
+        Compute the KTO loss for a batch of policy and reference model log probabilities.
+
+        Args:
+            chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
+            rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
+            kl_logps: KL divergence of the policy model. Shape: (batch_size,)
+            ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
+            ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+            ref_kl_logps: KL divergence of the reference model. Shape: (batch_size,)
+
+        The hyperparameters beta (the temperature parameter in the KTO paper), desirable_weight,
+        and undesirable_weight are set in __init__. Refer to the KTO paper for details about
+        these hyperparameters: https://arxiv.org/pdf/2402.01306
+        """
+        kl = (kl_logps - ref_kl_logps).mean().detach()
+        # average the detached KL estimate across all data-parallel ranks
+        dist.all_reduce(kl, op=dist.ReduceOp.SUM)
+        kl = (kl / dist.get_world_size()).clamp(min=0)
+        # kl = 0
+
+        if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
+            chosen_logratios = chosen_logps - ref_chosen_logps
+            chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
+            chosen_rewards = self.beta * chosen_logratios.detach()
+        else:
+            # important to keep the policy dtype and device; otherwise an error will occur during all_gather
+            chosen_losses = torch.tensor([], dtype=kl_logps.dtype, device=kl_logps.device)
+            chosen_rewards = torch.tensor([], dtype=kl_logps.dtype, device=kl_logps.device)
+
+        if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
+            rejected_logratios = rejected_logps - ref_rejected_logps
+            rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
+            rejected_rewards = self.beta * rejected_logratios.detach()
+        else:
+            # important to keep the policy dtype and device; otherwise an error will occur during all_gather
+            rejected_losses = torch.tensor([], dtype=kl_logps.dtype, device=kl_logps.device)
+            rejected_rewards = torch.tensor([], dtype=kl_logps.dtype, device=kl_logps.device)
+
+        losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
+
+        return losses, chosen_rewards, rejected_rewards, kl
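
For intuition: per example, the new forward computes 1 - sigmoid(beta * (logratio - kl)) for desirable responses and 1 - sigmoid(beta * (kl - logratio)) for undesirable ones, then averages the weighted concatenation. A minimal standalone sketch of that math with hypothetical toy values, skipping the distributed all_reduce (single process, so kl is used as-is):

    import torch

    beta, desirable_weight, undesirable_weight = 0.1, 1.0, 1.0

    # hypothetical per-sequence log-probabilities: two chosen, one rejected
    chosen_logps = torch.tensor([-12.0, -15.0])
    ref_chosen_logps = torch.tensor([-13.0, -14.0])
    rejected_logps = torch.tensor([-20.0])
    ref_rejected_logps = torch.tensor([-18.0])
    kl = torch.tensor(0.5).clamp(min=0)  # stands in for the all-reduced KL estimate

    chosen_losses = 1 - torch.sigmoid(beta * ((chosen_logps - ref_chosen_logps) - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - (rejected_logps - ref_rejected_logps)))
    loss = torch.cat((desirable_weight * chosen_losses, undesirable_weight * rejected_losses), 0).mean()
    print(loss)  # scalar KTO loss for this toy batch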
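Because forward calls dist.all_reduce, a process group must be initialized before the loss is used; a single-process gloo group is enough for a smoke test. A sketch, assuming this module is importable as coati.models.loss (the file path is not shown in this hunk) and using made-up log-probabilities:

    import os

    import torch
    import torch.distributed as dist

    from coati.models.loss import KTOLoss  # assumed import path

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)
    chosen_logps = torch.tensor([-12.0, -15.0])
    rejected_logps = torch.tensor([-20.0])
    kl_logps = torch.tensor([-16.0, -16.5, -15.5])
    losses, chosen_rewards, rejected_rewards, kl = loss_fn(
        chosen_logps,
        rejected_logps,
        kl_logps,
        chosen_logps - 0.5,    # stand-in reference logps
        rejected_logps + 0.5,
        kl_logps + 0.1,
    )
    dist.destroy_process_group()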
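As an aside on the OddsRatioLoss context lines above: logp - torch.log(-torch.exp(logp) + 1.0001) is log(p / (1 - p)), the log-odds, with the extra 1e-4 apparently guarding against log(0) as p approaches 1. A quick numeric check of that identity:

    import torch

    p = torch.tensor(0.9)
    logp = torch.log(p)
    log_odds = logp - torch.log(-torch.exp(logp) + 1.0001)  # same expression as in the diff
    print(log_odds.item(), torch.log(p / (1 - p)).item())   # ~2.196 vs ~2.197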