mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-20 00:47:13 +00:00
fix small bug
This commit is contained in:
parent
245c8c2fbc
commit
b314da19f4
@@ -294,7 +294,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
reference_action_log_probs = memory_efficient_logprob(
|
||||
-reference_model_outputs["outputs"]["logits"],
|
||||
+reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
@@ -321,7 +321,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
def _criterion(outputs, inputs):
|
||||
action_logits = outputs.logits
|
||||
action_log_probs = memory_efficient_logprob(
|
||||
-action_logits,
|
||||
+action_logits / self.generate_config["temperature"],
|
||||
inputs["input_ids"],
|
||||
num_action,
|
||||
shard_config=self.plugin.shard_config,
|
||||
@@ -388,7 +388,7 @@ class GRPOConsumer(BaseConsumer):
|
||||
reference_model_logits / self.generate_config["temperature"],
|
||||
input_ids_forward_micro_batch,
|
||||
num_action,
|
||||
-self.plugin.shard_config,
|
||||
+shard_config=self.plugin.shard_config,
|
||||
)
|
||||
per_token_kl = (
|
||||
torch.exp(reference_action_log_probs - action_log_probs)
|
||||
|
Loading…
Reference in New Issue
Block a user