Support overall loss, update KTO logging

This commit is contained in:
YeAnbang
2024-08-02 06:51:38 +00:00
parent 75c963686f
commit 0b2d55c4ab
15 changed files with 119 additions and 119 deletions

View File

@@ -52,6 +52,7 @@ class ORPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -67,6 +68,7 @@ class ORPOTrainer(SLTrainer):
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
@@ -130,6 +132,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
@@ -263,6 +270,11 @@ class ORPOTrainer(SLTrainer):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)
batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),