Mirror of https://github.com/hpcaitech/ColossalAI.git
commit 6b06430ca4 (parent e3d56cbd86)

    fix small bug
@@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer):
             if self.booster.plugin.stage_manager.is_last_stage():
                 reference_action_log_probs = memory_efficient_logprob(
-                    reference_model_outputs["outputs"]["logits"],
+                    reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                     input_ids_forward_micro_batch,
                     num_action,
                     shard_config=self.plugin.shard_config,
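
The only change in this hunk divides the reference model's logits by the sampling temperature before memory_efficient_logprob turns them into per-token log-probabilities, so the reference log-probs describe the same temperature-scaled distribution the rollout engine actually sampled from. A minimal sketch of the idea, with a plain log-softmax gather standing in for memory_efficient_logprob and purely illustrative shapes and temperature:

import torch
import torch.nn.functional as F

# Toy values: 1 sequence, 4 generated tokens, vocabulary of 8.
logits = torch.randn(1, 4, 8)
token_ids = torch.randint(0, 8, (1, 4))
temperature = 0.8  # illustrative; the real value comes from generate_config

# Log-probs of the sampled tokens under the temperature-scaled softmax,
# i.e. the distribution the sampler used at generation time.
scaled = F.log_softmax(logits / temperature, dim=-1)
per_token_logprob = scaled.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)

# Without the division the log-probs correspond to temperature == 1.0 and
# disagree with the rollout distribution whenever temperature != 1.0.
unscaled = F.log_softmax(logits, dim=-1).gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
print((per_token_logprob - unscaled).abs().max())  # nonzero unless temperature == 1.0
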
@@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer):
             def _criterion(outputs, inputs):
                 action_logits = outputs.logits
                 action_log_probs = memory_efficient_logprob(
-                    action_logits,
+                    action_logits / self.generate_config["temperature"],
                     inputs["input_ids"],
                     num_action,
                     shard_config=self.plugin.shard_config,
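
The second hunk applies the same temperature scaling to the policy logits inside the pipeline-parallel criterion, so the actor-side log-probabilities are computed under the same distribution as the reference-side ones above. A hedged sketch of the closure pattern, again with a log-softmax gather in place of memory_efficient_logprob and with all names and shapes illustrative:

import torch
import torch.nn.functional as F

def make_criterion(input_ids: torch.Tensor, temperature: float):
    # Build a criterion-style closure that recomputes per-token log-probs
    # from the model's logits, mirroring the pattern in the diff.
    def _criterion(logits: torch.Tensor) -> torch.Tensor:
        # Scale by the sampling temperature before the softmax so the
        # log-probs match the rollout distribution.
        log_probs = F.log_softmax(logits / temperature, dim=-1)
        return log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1)
    return _criterion

# Usage with toy tensors: batch of 2, 5 tokens, vocabulary of 8.
ids = torch.randint(0, 8, (2, 5))
criterion = make_criterion(ids, temperature=0.8)
print(criterion(torch.randn(2, 5, 8)).shape)  # torch.Size([2, 5])
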
@@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer):
                 reference_model_logits / self.generate_config["temperature"],
                 input_ids_forward_micro_batch,
                 num_action,
-                self.plugin.shard_config,
+                shard_config=self.plugin.shard_config,
             )
             per_token_kl = (
                 torch.exp(reference_action_log_probs - action_log_probs)
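
The third hunk carries the same temperature fix in this code path and additionally passes shard_config by keyword instead of positionally, matching the other call sites above. The per_token_kl expression that begins at the end of the hunk is, in GRPO-style trainers, typically the non-negative estimator exp(ref - actor) - (ref - actor) - 1; the rest of that expression is not shown in the diff, so the sketch below fills it in under that assumption:

import torch

def per_token_kl(ref_logprob: torch.Tensor, actor_logprob: torch.Tensor) -> torch.Tensor:
    # k3-style per-token KL estimate: exp(ref - actor) - (ref - actor) - 1.
    # Always non-negative; the diff only shows the exp(...) term of it.
    diff = ref_logprob - actor_logprob
    return torch.exp(diff) - diff - 1

ref = torch.log(torch.tensor([0.20, 0.50, 0.30]))
act = torch.log(torch.tensor([0.25, 0.45, 0.30]))
print(per_token_kl(ref, act))  # small non-negative values
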