mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-11 14:24:27 +00:00
remove unused code
This commit is contained in:
parent
09a3173a49
commit
061d8cb3b6
@@ -252,10 +252,6 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
- (inputs["reference_action_log_probs"] - action_log_probs)
|
- (inputs["reference_action_log_probs"] - action_log_probs)
|
||||||
- 1
|
- 1
|
||||||
)
|
)
|
||||||
decode_tokens_100 = self.tokenizer.batch_decode(
|
|
||||||
input_ids_forward_micro_batch[:, -num_action:],
|
|
||||||
skip_special_tokens=False,
|
|
||||||
)
|
|
||||||
loss, skip_update, _ = self.policy_loss_fn(
|
loss, skip_update, _ = self.policy_loss_fn(
|
||||||
action_log_probs,
|
action_log_probs,
|
||||||
action_log_probs,
|
action_log_probs,
|
||||||
@@ -277,7 +273,7 @@ class GRPOConsumer(BaseConsumer):
|
|||||||
loss = policy_model_outputs["loss"]
|
loss = policy_model_outputs["loss"]
|
||||||
|
|
||||||
if self.booster.plugin.stage_manager.is_last_stage():
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
# calculate kl
|
# calculate kl, as we cannot do this inside callback, kl needs be calculate again
|
||||||
action_logits = policy_model_outputs["outputs"]["logits"]
|
action_logits = policy_model_outputs["outputs"]["logits"]
|
||||||
action_log_probs = calc_action_log_probs(
|
action_log_probs = calc_action_log_probs(
|
||||||
action_logits / self.generate_config["temperature"],
|
action_logits / self.generate_config["temperature"],
|
||||||
|
Loading…
Reference in New Issue
Block a user