Support overall loss, update KTO logging

commit 0b2d55c4ab
parent 75c963686f
Author: YeAnbang
Date:   2024-08-02 06:51:38 +00:00
15 changed files with 119 additions and 119 deletions


@@ -56,6 +56,7 @@ class DPOTrainer(SLTrainer):
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class DPOTrainer(SLTrainer):
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
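For context, the DpoLoss(beta, gamma) constructed above points at the usual DPO objective; the sketch below shows one common formulation in which gamma acts as a reward margin. Treating gamma this way is an assumption for illustration, not something this diff confirms about coati's DpoLoss.

import torch
import torch.nn.functional as F

def dpo_loss_sketch(chosen_logratios: torch.Tensor,
                    reject_logratios: torch.Tensor,
                    beta: float = 0.1,
                    gamma: float = 0.0) -> torch.Tensor:
    # chosen_logratios: log pi_theta(chosen) - log pi_ref(chosen), summed over tokens
    # reject_logratios: log pi_theta(rejected) - log pi_ref(rejected)
    # gamma is used here as a target margin between the two implicit rewards
    logits = beta * (chosen_logratios - reject_logratios) - gamma
    return -F.logsigmoid(logits).mean()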
@@ -135,6 +137,10 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_all_logits = self.model(
@@ -284,6 +290,9 @@ class DPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
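The effect of the new apply_loss_mask flag is easiest to see on the per-sequence log-probability sum: with the flag off, fill_(1.0) makes every token contribute instead of only the tokens the dataset marked for loss. A minimal sketch under that reading (tensor names are illustrative, not the trainer's internals):

import torch

def sequence_logps(per_token_logps: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # per_token_logps: (batch, seq_len) log-probabilities of the target tokens
    # loss_mask:       (batch, seq_len) with 1.0 where a token should count
    return (per_token_logps * loss_mask).sum(dim=-1)

per_token_logps = torch.randn(2, 8)
loss_mask = torch.zeros(2, 8)
loss_mask[:, 4:] = 1.0   # only the response half counts

masked_sum = sequence_logps(per_token_logps, loss_mask)               # apply_loss_mask=True
overall_sum = sequence_logps(per_token_logps, loss_mask.fill_(1.0))   # apply_loss_mask=False
print(masked_sum, overall_sum)

Note that fill_ mutates the mask in place, so once the flag is off the all-ones mask is what the rest of the step sees, matching the lines added in both the training and evaluation paths above.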