mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-19 00:17:18 +00:00
fix small bug
This commit is contained in:
parent
e3d56cbd86
commit
6b06430ca4
@@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_action_log_probs = memory_efficient_logprob(
|
||||
- reference_model_outputs["outputs"]["logits"],
+ reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
@@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
- action_logits,
+ action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
@@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
- self.plugin.shard_config,
+ shard_config=self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
|
Loading…
Reference in New Issue
Block a user