mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
fix style, add kto data sample
This commit is contained in:
@@ -251,17 +251,13 @@ class KTOLoss(nn.Module):
|
||||
# all gather
|
||||
dist.all_reduce(kl, op=dist.ReduceOp.SUM)
|
||||
kl = (kl / dist.get_world_size()).clamp(min=0)
|
||||
# kl = 0
|
||||
|
||||
if chosen_logps.shape[0] != 0 and ref_chosen_logps.shape[0] != 0:
|
||||
chosen_logratios = chosen_logps - ref_chosen_logps
|
||||
chosen_losses = 1 - nn.functional.sigmoid(self.beta * (chosen_logratios - kl))
|
||||
chosen_rewards = self.beta * chosen_logratios.detach()
|
||||
else:
|
||||
# important to cast to policy_dtype; otherwise error will occur during all_gather
|
||||
chosen_losses = torch.Tensor([]).to(
|
||||
kl_logps.device
|
||||
) # torch.Tensor(0.).to(chosen_logps.dtype).to(chosen_logps.device)
|
||||
chosen_losses = torch.Tensor([]).to(kl_logps.device)
|
||||
chosen_rewards = torch.Tensor([]).to(kl_logps.device)
|
||||
|
||||
if rejected_logps.shape[0] != 0 and ref_rejected_logps.shape[0] != 0:
|
||||
@@ -269,10 +265,7 @@ class KTOLoss(nn.Module):
|
||||
rejected_losses = 1 - nn.functional.sigmoid(self.beta * (kl - rejected_logratios))
|
||||
rejected_rewards = self.beta * rejected_logratios.detach()
|
||||
else:
|
||||
# important to cast to policy_dtype; otherwise error will occur during all_gather
|
||||
rejected_losses = torch.Tensor([]).to(
|
||||
kl_logps.device
|
||||
) # torch.Tensor(0.).to(rejected_logps.dtype).to(rejected_logps.device)
|
||||
rejected_losses = torch.Tensor([]).to(kl_logps.device)
|
||||
rejected_rewards = torch.Tensor([]).to(kl_logps.device)
|
||||
|
||||
losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
|
||||
|
Reference in New Issue
Block a user