fix style, add kto data sample

This commit is contained in:
YeAnbang
2024-07-18 08:38:56 +00:00
parent 845ea7214e
commit 544b7a38a1
11 changed files with 68 additions and 53 deletions

View File

@@ -251,17 +251,13 @@ class KTOLoss(nn.Module):
# all gather
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
kl = (kl / dist.get_world_size()).clamp(min=0)
# kl = 0
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
chosen_logratios = chosen_logps - ref_chosen_logps
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
chosen_rewards = self.beta * chosen_logratios.detach()
else:
# important to cast to policy_dtype; otherwise error will occur during all_gather
chosen_losses = torch.Tensor([]).to(
kl_logps.device
) # torch.Tensor(0.).to(chosen_logps.dtype).to(chosen_logps.device)
chosen_losses = torch.Tensor([]).to(kl_logps.device)
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
@@ -269,10 +265,7 @@ class KTOLoss(nn.Module):
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
rejected_rewards = self.beta * rejected_logratios.detach()
else:
# important to cast to policy_dtype; otherwise error will occur during all_gather
rejected_losses = torch.Tensor([]).to(
kl_logps.device
) # torch.Tensor(0.).to(rejected_logps.dtype).to(rejected_logps.device)
rejected_losses = torch.Tensor([]).to(kl_logps.device)
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()