YeAnbang
2024-07-18 07:54:11 +00:00
parent b3594d4d68
commit 09d5ffca1a
27 changed files with 1739 additions and 63 deletions

View File

@@ -2,7 +2,7 @@ from .base import BaseModel
from .critic import Critic
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
from .lora import convert_to_lora_module
from .loss import DpoLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
from .reward_model import RewardModel
from .utils import disable_dropout
@@ -16,7 +16,7 @@ __all__ = [
"LogExpLoss",
"convert_to_lora_module",
"DpoLoss",
"generate",
"KTOLoss" "generate",
"generate_streaming",
"disable_dropout",
"update_model_kwargs_fn",

View File

@@ -42,7 +42,6 @@ class BaseModel(nn.Module):
out = self.model(dummy_input)
self.last_hidden_state_size = out.last_hidden_state.shape[-1]
self.model = self.model.cpu()
# print("self.last_hidden_state_size: ",self.last_hidden_state_size)
def resize_token_embeddings(self, *args, **kwargs):
"""

View File

@@ -50,7 +50,7 @@ class LoraLinear(lora.LoRALayer, nn.Module):
self.fan_in_fan_out = fan_in_fan_out
# Actual trainable parameters
if r > 0:
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
self.scaling = self.lora_alpha / self.r
# Freezing the pre-trained weight matrix
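Side note on this hunk: with `lora_A` now created with `requires_grad=False`, only `lora_B` should remain trainable in a converted module. Below is a minimal sketch for checking that, assuming the package is importable as `coati.models` and that `convert_to_lora_module` (imported in the first hunk) accepts a `(module, lora_rank)` call; both are assumptions for illustration, not part of this commit.

```python
# Hypothetical check, not part of this commit.
import torch.nn as nn

from coati.models import convert_to_lora_module  # import path assumed

base = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
lora_model = convert_to_lora_module(base, lora_rank=8)  # signature assumed

# With this change, lora_A is frozen at construction, so only lora_B
# (plus any bias terms left trainable) should be printed here.
for name, param in lora_model.named_parameters():
    if param.requires_grad:
        print(name, tuple(param.shape))
```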

View File

@@ -5,6 +5,7 @@ loss functions
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from .utils import masked_mean
@@ -201,7 +202,79 @@ class OddsRatioLoss(nn.Module):
chosen_odds_masked = torch.sum(chosen_odds * chosen_loss_mask.float()) / torch.sum(chosen_loss_mask)
reject_odds = reject_logp - torch.log(-torch.exp(reject_logp) + 1.0001)
reject_odds_masked = torch.sum(reject_odds * reject_loss_mask.float()) / torch.sum(reject_loss_mask)
# print("chosen_odds_masked", chosen_odds_masked[0], "reject_odds_masked", reject_odds_masked[0])
log_odds_ratio = chosen_odds_masked - reject_odds_masked
ratio = torch.log(torch.nn.functional.sigmoid(log_odds_ratio))
return ratio.to(dtype=torch.bfloat16), log_odds_ratio
class KTOLoss(nn.Module):
def __init__(self, beta: float = 0.1, desirable_weight: float = 1.0, undesirable_weight: float = 1.0):
"""
Args:
beta: The temperature parameter in the KTO paper.
desirable_weight: The weight for the desirable responses.
undesirable_weight: The weight for the undesirable responses.
"""
super().__init__()
self.beta = beta
self.desirable_weight = desirable_weight
self.undesirable_weight = undesirable_weight
def forward(
self,
chosen_logps: torch.Tensor,
rejected_logps: torch.Tensor,
kl_logps: torch.Tensor,
ref_chosen_logps: torch.Tensor,
ref_rejected_logps: torch.Tensor,
ref_kl_logps: torch.Tensor,
):
"""
Reference:
https://github.com/huggingface/trl/blob/a2adfb836a90d1e37b1253ab43dace05f1241e04/trl/trainer/kto_trainer.py#L585
Compute the KTO loss for a batch of policy and reference model log probabilities.
Args:
chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
kl_logps: Log probabilities of the policy model on the samples used to estimate the KL term. Shape: (batch_size,)
ref_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
ref_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
ref_kl_logps: Log probabilities of the reference model on the same KL samples. Shape: (batch_size,)
beta: The temperature parameter from the KTO paper (set in __init__).
desirable_weight: The weight for the desirable responses (set in __init__).
undesirable_weight: The weight for the undesirable responses (set in __init__).
Refer to the KTO paper for details on these hyperparameters: https://arxiv.org/pdf/2402.01306
"""
kl = (kl_logps - ref_kl_logps).mean().detach()
# average the per-batch KL estimate across all ranks
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# no desirable samples in this micro-batch; use empty tensors so the
# concatenation below and any cross-rank gather still succeed
chosen_losses = torch.Tensor([]).to(kl_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
rejected_logratios = rejected_logps - ref_rejected_logps
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# no undesirable samples in this micro-batch; use empty tensors so the
# concatenation below and any cross-rank gather still succeed
rejected_losses = torch.Tensor([]).to(kl_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
return losses, chosen_rewards, rejected_rewards, kl
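
For reference, a minimal single-process sketch of calling the new `KTOLoss` on dummy values. Since `forward` performs `dist.all_reduce`, the sketch spins up a one-rank gloo process group; that setup, the sample numbers, and the `coati.models.loss` import path are assumptions for illustration, not part of this commit.

```python
# Minimal sketch with dummy inputs; values and setup are illustrative only.
import os

import torch
import torch.distributed as dist

from coati.models.loss import KTOLoss  # import path assumed from this diff

# One-rank process group so that dist.all_reduce inside forward() works.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0)

# Dummy per-sample sequence log-probabilities; shapes follow the docstring.
chosen_logps = torch.tensor([-10.0, -12.5])
rejected_logps = torch.tensor([-11.0])
kl_logps = torch.tensor([-10.5, -12.0, -11.5])
ref_chosen_logps = torch.tensor([-10.5, -12.0])
ref_rejected_logps = torch.tensor([-10.8])
ref_kl_logps = torch.tensor([-10.6, -12.1, -11.4])

loss, chosen_rewards, rejected_rewards, kl = loss_fn(
    chosen_logps, rejected_logps, kl_logps,
    ref_chosen_logps, ref_rejected_logps, ref_kl_logps,
)
print(loss.item(), chosen_rewards, rejected_rewards, kl.item())

dist.destroy_process_group()
```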