add SimPO

This commit is contained in:
YeAnbang
2024-06-24 02:12:20 +00:00
parent 84eab13078
commit 82aecd6374
14 changed files with 128 additions and 70 deletions

View File

@@ -88,11 +88,22 @@ class DpoLoss(nn.Module):
"""
Dpo loss
Details: https://arxiv.org/pdf/2305.18290.pdf
SimPO loss:
Details: https://arxiv.org/pdf/2405.14734.pdf
"""
def __init__(self, beta: float = 0.1):
def __init__(self, beta: float = 0.1, gamma: float = 0.0):
"""
Args:
beta: The temperature parameter in the DPO paper.
gamma: The margin parameter in the SimPO paper.
length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
Refer to the length normalization in the SimPO paper
"""
super().__init__()
self.beta = beta
self.gamma = gamma
def forward(
self,
@@ -103,7 +114,7 @@ class DpoLoss(nn.Module):
chosen_mask: torch.Tensor,
reject_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute the DPO loss for a batch of policy and reference model log probabilities.
"""Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.
# adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328
@@ -112,6 +123,8 @@ class DpoLoss(nn.Module):
logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)
Returns:
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@@ -126,13 +139,12 @@ class DpoLoss(nn.Module):
if len(logprob_ref_chosen.shape) == 2:
ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
else:
ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
ref_logratios = logprob_ref_chosen - logprob_ref_reject
else:
# If no reference model is provided
ref_logratios = 0.0
pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
logits = pi_logratios - ref_logratios
logits = pi_logratios - ref_logratios - self.gamma / self.beta
losses = -torch.nn.functional.logsigmoid(self.beta * logits)
# Calculate rewards for logging

View File

@@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
return mean
def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
def calc_masked_log_probs(
logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False
) -> torch.Tensor:
"""
Calculate the masked log probabilities for a given sequence of logits.
@@ -103,7 +105,13 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas
"""
# logits are probabilities of the next token, so we shift them to the left by one
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs * mask
if not length_normalization:
return log_probs * mask
else:
if torch.any(mask.sum(dim=-1) == 0):
print("Mask should not be all zeros.")
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:

View File

@@ -53,6 +53,8 @@ class DPOTrainer(SLTrainer):
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
@@ -63,7 +65,7 @@ class DPOTrainer(SLTrainer):
self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta)
self.actor_loss_fn = DpoLoss(beta, gamma)
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
@@ -71,6 +73,7 @@ class DPOTrainer(SLTrainer):
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
self.length_normalization = length_normalization
def _before_fit(
self,
@@ -140,9 +143,13 @@ class DPOTrainer(SLTrainer):
)["logits"].to(torch.float32)
actor_chosen_logits = actor_all_logits[:batch_size]
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
if self.ref_model is not None:
self.ref_model.eval()
@@ -154,10 +161,10 @@ class DPOTrainer(SLTrainer):
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
else:
logprob_ref_chosen = None
@@ -288,11 +295,11 @@ class DPOTrainer(SLTrainer):
actor_reject_logits = actor_all_logits[batch_size:]
logprob_actor_chosen = calc_masked_log_probs(
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_actor_reject = calc_masked_log_probs(
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
self.ref_model.eval()
@@ -303,8 +310,12 @@ class DPOTrainer(SLTrainer):
)["logits"].to(torch.float32)
ref_chosen_logits = ref_all_logits[:batch_size]
ref_reject_logits = ref_all_logits[batch_size:]
logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
logprob_ref_chosen = calc_masked_log_probs(
ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
)
logprob_ref_reject = calc_masked_log_probs(
ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
)
losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
logprob_actor_chosen,

View File

@@ -102,6 +102,8 @@ class SFTTrainer(SLTrainer):
batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
loss = outputs.loss
step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
self.booster.backward(loss=loss, optimizer=self.optimizer)
loss_mean = all_reduce_mean(tensor=loss)