mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
Support overall loss, update KTO logging
This commit is contained in:
@@ -102,6 +102,7 @@ class PPOTrainer(OLTrainer):
|
||||
sample_buffer: bool = False,
|
||||
dataloader_pin_memory: bool = True,
|
||||
offload_inference_models: bool = True,
|
||||
apply_loss_mask: bool = True,
|
||||
accumulation_steps: int = 1,
|
||||
save_interval: int = 0,
|
||||
save_dir: str = None,
|
||||
@@ -140,6 +141,7 @@ class PPOTrainer(OLTrainer):
|
||||
self.actor_optim = actor_optim
|
||||
self.critic_optim = critic_optim
|
||||
self.save_interval = save_interval
|
||||
self.apply_loss_mask = apply_loss_mask
|
||||
self.coordinator = coordinator
|
||||
self.actor_save_dir = os.path.join(save_dir, "actor")
|
||||
self.critic_save_dir = os.path.join(save_dir, "critic")
|
||||
@@ -229,7 +231,10 @@ class PPOTrainer(OLTrainer):
|
||||
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
|
||||
|
||||
actor_loss, to_skip, max_ratio = self.actor_loss_fn(
|
||||
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
|
||||
action_log_probs,
|
||||
experience.action_log_probs,
|
||||
experience.advantages,
|
||||
action_mask=experience.action_mask if self.apply_loss_mask else None,
|
||||
)
|
||||
actor_loss = (1 - self.ptx_coef) * actor_loss
|
||||
if not to_skip:
|
||||
@@ -249,7 +254,10 @@ class PPOTrainer(OLTrainer):
|
||||
input_ids=experience.sequences, attention_mask=experience.attention_mask
|
||||
) # [batch size, prompt_length + response_length]
|
||||
critic_loss = self.critic_loss_fn(
|
||||
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
|
||||
values[:, -num_actions:],
|
||||
experience.values,
|
||||
experience.advantages,
|
||||
action_mask=experience.action_mask if self.apply_loss_mask else None,
|
||||
)
|
||||
critic_loss = critic_loss * self.vf_coef
|
||||
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
|
||||
|
Reference in New Issue
Block a user