fix small bug: apply temperature scaling to logits in memory_efficient_logprob calls and pass shard_config as a keyword argument

This commit is contained in:
YeAnbang 2025-06-19 01:37:52 +00:00
parent e3d56cbd86
commit 6b06430ca4

View File

@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer):
if self.booster.plugin.stage_manager.is_last_stage(): if self.booster.plugin.stage_manager.is_last_stage():
reference_action_log_probs = memory_efficient_logprob( reference_action_log_probs = memory_efficient_logprob(
reference_model_outputs["outputs"]["logits"], reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
input_ids_forward_micro_batch, input_ids_forward_micro_batch,
num_action, num_action,
shard_config=self.plugin.shard_config, shard_config=self.plugin.shard_config,
@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer):
def _criterion(outputs, inputs): def _criterion(outputs, inputs):
action_logits = outputs.logits action_logits = outputs.logits
action_log_probs = memory_efficient_logprob( action_log_probs = memory_efficient_logprob(
action_logits, action_logits / self.generate_config["temperature"],
inputs["input_ids"], inputs["input_ids"],
num_action, num_action,
shard_config=self.plugin.shard_config, shard_config=self.plugin.shard_config,
@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer):
reference_model_logits / self.generate_config["temperature"], reference_model_logits / self.generate_config["temperature"],
input_ids_forward_micro_batch, input_ids_forward_micro_batch,
num_action, num_action,
self.plugin.shard_config, shard_config=self.plugin.shard_config,
) )
per_token_kl = ( per_token_kl = (
torch.exp(reference_action_log_probs - action_log_probs) torch.exp(reference_action_log_probs - action_log_probs)