fix small bug: divide logits by the sampling temperature before computing log-probs in GRPOConsumer (and pass shard_config by keyword)

This commit is contained in:
YeAnbang 2025-06-19 01:37:52 +00:00
parent e3d56cbd86
commit 6b06430ca4

View File

@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer):
if self.booster.plugin.stage_manager.is_last_stage():
reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"],
reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
shard_config=self.plugin.shard_config,
@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer):
def _criterion(outputs, inputs):
action_logits = outputs.logits
action_log_probs = memory_efficient_logprob(
action_logits,
action_logits / self.generate_config["temperature"],
inputs["input_ids"],
num_action,
shard_config=self.plugin.shard_config,
@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer):
reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch,
num_action,
self.plugin.shard_config,
shard_config=self.plugin.shard_config,
)
per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs)