From 6b06430ca418715240c10a28baf14ff3e4b35ca0 Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Thu, 19 Jun 2025 01:37:52 +0000 Subject: [PATCH] fix small bug --- .../ColossalChat/coati/distributed/grpo_consumer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 5e8f329eb..d7d6221a1 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer): if self.booster.plugin.stage_manager.is_last_stage(): reference_action_log_probs = memory_efficient_logprob( - reference_model_outputs["outputs"]["logits"], + reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, shard_config=self.plugin.shard_config, @@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer): def _criterion(outputs, inputs): action_logits = outputs.logits action_log_probs = memory_efficient_logprob( - action_logits, + action_logits / self.generate_config["temperature"], inputs["input_ids"], num_action, shard_config=self.plugin.shard_config, @@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer): reference_model_logits / self.generate_config["temperature"], input_ids_forward_micro_batch, num_action, - self.plugin.shard_config, + shard_config=self.plugin.shard_config, ) per_token_kl = ( torch.exp(reference_action_log_probs - action_log_probs)