mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 02:51:59 +00:00
add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
This commit is contained in:
@@ -139,7 +139,7 @@ class DPOTrainer(SLTrainer):
|
||||
actor_all_logits = self.model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
logprob_actor_chosen = calc_masked_log_probs(
|
||||
@@ -156,7 +156,7 @@ class DPOTrainer(SLTrainer):
|
||||
ref_all_logits = self.ref_model(
|
||||
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
attention_mask=torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
@@ -225,7 +225,7 @@ class DPOTrainer(SLTrainer):
|
||||
)
|
||||
self.accumulative_meter.reset()
|
||||
|
||||
if (self.num_train_step + 1) % self.save_interval == 0:
|
||||
if self.save_dir is not None and (self.num_train_step + 1) % self.save_interval == 0:
|
||||
# save checkpoint
|
||||
self.coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
||||
save_checkpoint(
|
||||
@@ -289,7 +289,7 @@ class DPOTrainer(SLTrainer):
|
||||
actor_all_logits = self.model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
actor_chosen_logits = actor_all_logits[:batch_size]
|
||||
actor_reject_logits = actor_all_logits[batch_size:]
|
||||
|
||||
@@ -306,7 +306,7 @@ class DPOTrainer(SLTrainer):
|
||||
ref_all_logits = self.ref_model(
|
||||
torch.cat([chosen_input_ids, reject_input_ids]),
|
||||
torch.cat([chosen_attention_mask, reject_attention_mask]),
|
||||
)["logits"].to(torch.float32)
|
||||
)["logits"]
|
||||
ref_chosen_logits = ref_all_logits[:batch_size]
|
||||
ref_reject_logits = ref_all_logits[batch_size:]
|
||||
logprob_ref_chosen = calc_masked_log_probs(
|
||||
|
Reference in New Issue
Block a user