add SimPO

This commit is contained in:
YeAnbang
2024-06-24 02:12:20 +00:00
parent 84eab13078
commit 82aecd6374
14 changed files with 128 additions and 70 deletions

View File

@@ -53,6 +53,8 @@ class DPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -63,7 +65,7 @@ class DPOTrainer(SLTrainer):
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta)
self.actor_loss_fn = DpoLoss(beta, gamma)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -71,6 +73,7 @@ class DPOTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
self.length_normalization = length_normalization
def _before_fit(
self,
@@ -140,9 +143,13 @@ class DPOTrainer(SLTrainer):
)["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
if self.ref_model is not None:
self.ref_model.eval()
@@ -154,10 +161,10 @@ class DPOTrainer(SLTrainer):
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
@@ -288,11 +295,11 @@ class DPOTrainer(SLTrainer):
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
self.ref_model.eval()
@@ -303,8 +310,12 @@ class DPOTrainer(SLTrainer):
)["logits"].to(torch.float32)
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,