mirror of https://github.com/hpcaitech/ColossalAI.git

add SimPO
@@ -88,11 +88,22 @@ class DpoLoss(nn.Module):
     """
     Dpo loss
     Details: https://arxiv.org/pdf/2305.18290.pdf
+
+    SimPO loss:
+    Details: https://arxiv.org/pdf/2405.14734.pdf
     """

-    def __init__(self, beta: float = 0.1):
+    def __init__(self, beta: float = 0.1, gamma: float = 0.0):
+        """
+        Args:
+            beta: The temperature parameter in the DPO paper.
+            gamma: The margin parameter in the SimPO paper.
+            length_normalization: Whether to normalize the loss by the length of chosen and rejected responses.
+                Refer to the length normalization in the SimPO paper
+        """
         super().__init__()
         self.beta = beta
+        self.gamma = gamma

     def forward(
         self,
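For orientation (not part of the diff): the SimPO objective referenced in the new docstring drops the reference model, scores each response by its length-normalized log probability, and subtracts a target reward margin gamma before the sigmoid. A minimal sketch, with hypothetical function and argument names:

    import torch
    import torch.nn.functional as F

    def simpo_loss_sketch(sum_logp_chosen, sum_logp_rejected, len_chosen, len_rejected,
                          beta=2.0, gamma=1.0):
        # Implicit reward = beta * average per-token log prob of the policy model;
        # no reference-model term appears, unlike standard DPO.
        reward_chosen = beta * sum_logp_chosen / len_chosen
        reward_rejected = beta * sum_logp_rejected / len_rejected
        # Penalize pairs whose reward gap falls short of the target margin gamma.
        return -F.logsigmoid(reward_chosen - reward_rejected - gamma)

    # Example: chosen response of 10 tokens vs. rejected response of 25 tokens.
    loss = simpo_loss_sketch(torch.tensor(-12.0), torch.tensor(-40.0), 10, 25)

The implementation below reaches the same quantity differently: the length normalization happens in calc_masked_log_probs, the margin is folded into the logits as gamma / beta, and beta is applied once inside the logsigmoid.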
@@ -103,7 +114,7 @@ class DpoLoss(nn.Module):
         chosen_mask: torch.Tensor,
         reject_mask: torch.Tensor,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """Compute the DPO loss for a batch of policy and reference model log probabilities.
+        """Compute the DPO/SimPO loss for a batch of policy and reference model log probabilities.

         # adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L328

@@ -112,6 +123,8 @@ class DpoLoss(nn.Module):
             logprob_actor_reject: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
             logprob_ref_chosen: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
             logprob_ref_reject: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
+            chosen_mask: Mask tensor indicating which responses were chosen. Shape: (batch_size,)
+            reject_mask: Mask tensor indicating which responses were rejected. Shape: (batch_size,)

         Returns:
             A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
@@ -126,13 +139,12 @@ class DpoLoss(nn.Module):
             if len(logprob_ref_chosen.shape) == 2:
                 ref_logratios = logprob_ref_chosen.sum(-1) - logprob_ref_reject.sum(-1)
             else:
-                ref_logratios = logprob_ref_chosen.squeeze() - logprob_ref_reject.squeeze()
+                ref_logratios = logprob_ref_chosen - logprob_ref_reject
         else:
             # If no reference model is provided
             ref_logratios = 0.0

         pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
-        logits = pi_logratios - ref_logratios
+        logits = pi_logratios - ref_logratios - self.gamma / self.beta
         losses = -torch.nn.functional.logsigmoid(self.beta * logits)

         # Calculate rewards for logging
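To make the effect of the two changed lines concrete, here is a self-contained snippet (dummy values, not from the repository) showing that the new term only shifts the preference logits by gamma / beta, so gamma = 0.0 reproduces the original DPO loss:

    import torch

    beta, gamma = 0.1, 0.5
    # Summed log probabilities for one chosen/rejected pair (dummy values).
    logprob_actor_chosen = torch.tensor([[-1.0, -2.0]])
    logprob_actor_reject = torch.tensor([[-2.0, -3.0]])
    ref_logratios = 0.0  # e.g. SimPO-style training without a reference model

    pi_logratios = logprob_actor_chosen.sum(-1) - logprob_actor_reject.sum(-1)
    logits = pi_logratios - ref_logratios - gamma / beta
    losses = -torch.nn.functional.logsigmoid(beta * logits)
    # With gamma = 0.0 the margin term vanishes and the plain DPO loss is recovered;
    # a larger gamma demands a larger gap between chosen and rejected log probs.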
@@ -89,7 +89,9 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
     return mean


-def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor) -> torch.Tensor:
+def calc_masked_log_probs(
+    logits: torch.Tensor, sequences: torch.LongTensor, mask: torch.Tensor, length_normalization: bool = False
+) -> torch.Tensor:
     """
     Calculate the masked log probabilities for a given sequence of logits.

@@ -103,7 +105,13 @@ def calc_masked_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, mas
     """
     # logits are probabilities of the next token, so we shift them to the left by one
     log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
-    return log_probs * mask
+
+    if not length_normalization:
+        return log_probs * mask
+    else:
+        if torch.any(mask.sum(dim=-1) == 0):
+            print("Mask should not be all zeros.")
+        return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)


 def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
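The normalization branch above is easiest to see on toy tensors. The sketch below (dummy data, the _log_probs_from_logits step skipped) contrasts the old and new return values; the +0.01 keeps the division finite if a mask row is all zeros, the case the print statement warns about:

    import torch

    # Per-token log probs for one response of length 4, where only 3 tokens are scored.
    log_probs = torch.tensor([[-0.5, -1.0, -2.0, -4.0]])
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

    old_result = log_probs * mask                                            # previous behaviour
    new_result = log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)  # length-normalized

    print(old_result.sum(-1))  # total log prob of the response
    print(new_result.sum(-1))  # approximately the average per-token log prob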
@@ -53,6 +53,8 @@ class DPOTrainer(SLTrainer):
         tokenizer: PreTrainedTokenizerBase,
         max_epochs: int = 1,
         beta: float = 0.1,
+        gamma: float = 0.0,
+        length_normalization: bool = False,
         accumulation_steps: int = 1,
         start_epoch: int = 0,
         save_interval: int = 0,
@@ -63,7 +65,7 @@ class DPOTrainer(SLTrainer):
         self.ref_model = ref_model
         self.actor_scheduler = actor_lr_scheduler
         self.tokenizer = tokenizer
-        self.actor_loss_fn = DpoLoss(beta)
+        self.actor_loss_fn = DpoLoss(beta, gamma)
         self.save_interval = save_interval
         self.coordinator = coordinator
         self.save_dir = save_dir
@@ -71,6 +73,7 @@ class DPOTrainer(SLTrainer):
         self.accumulation_steps = accumulation_steps
         self.device = get_current_device()
         self.accumulative_meter = AccumulativeMeanMeter()
+        self.length_normalization = length_normalization

     def _before_fit(
         self,
@@ -140,9 +143,13 @@ class DPOTrainer(SLTrainer):
             )["logits"].to(torch.float32)
             actor_chosen_logits = actor_all_logits[:batch_size]
             actor_reject_logits = actor_all_logits[batch_size:]
-            logprob_actor_chosen = calc_masked_log_probs(actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
+            logprob_actor_chosen = calc_masked_log_probs(
+                actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+            )

-            logprob_actor_reject = calc_masked_log_probs(actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
+            logprob_actor_reject = calc_masked_log_probs(
+                actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+            )

             if self.ref_model is not None:
                 self.ref_model.eval()
@@ -154,10 +161,10 @@ class DPOTrainer(SLTrainer):
                 ref_chosen_logits = ref_all_logits[:batch_size]
                 ref_reject_logits = ref_all_logits[batch_size:]
                 logprob_ref_chosen = calc_masked_log_probs(
-                    ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
+                    ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
                 )
                 logprob_ref_reject = calc_masked_log_probs(
-                    ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
+                    ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
                 )
             else:
                 logprob_ref_chosen = None
@@ -288,11 +295,11 @@ class DPOTrainer(SLTrainer):
             actor_reject_logits = actor_all_logits[batch_size:]

             logprob_actor_chosen = calc_masked_log_probs(
-                actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:]
+                actor_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
             )

             logprob_actor_reject = calc_masked_log_probs(
-                actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:]
+                actor_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
             )

             self.ref_model.eval()
@@ -303,8 +310,12 @@ class DPOTrainer(SLTrainer):
             )["logits"].to(torch.float32)
             ref_chosen_logits = ref_all_logits[:batch_size]
             ref_reject_logits = ref_all_logits[batch_size:]
-            logprob_ref_chosen = calc_masked_log_probs(ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:])
-            logprob_ref_reject = calc_masked_log_probs(ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:])
+            logprob_ref_chosen = calc_masked_log_probs(
+                ref_chosen_logits, chosen_input_ids, chosen_loss_mask[:, 1:], self.length_normalization
+            )
+            logprob_ref_reject = calc_masked_log_probs(
+                ref_reject_logits, reject_input_ids, reject_loss_mask[:, 1:], self.length_normalization
+            )

             losses, chosen_rewards, rejected_rewards = self.actor_loss_fn(
                 logprob_actor_chosen,
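A hypothetical usage sketch of the new trainer options (only the arguments touched by this commit are shown; the remaining constructor arguments, elided with "...", come from the surrounding training script and are not part of this diff):

    trainer = DPOTrainer(
        ...,                        # remaining positional/keyword arguments from the training script
        beta=2.5,                   # illustrative temperature on the preference logits
        gamma=1.4,                  # illustrative SimPO target reward margin; 0.0 keeps plain DPO
        length_normalization=True,  # average per-token log probs, as in SimPO
    )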
@@ -102,6 +102,8 @@ class SFTTrainer(SLTrainer):
             batch_size = batch["input_ids"].size(0)
             outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
             loss = outputs.loss
+            step_bar.set_description(f"Epoch {epoch + 1}/{self.max_epochs} Loss: {loss.detach().cpu().item():.4f}")
+
             self.booster.backward(loss=loss, optimizer=self.optimizer)

             loss_mean = all_reduce_mean(tensor=loss)