diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5e8f329eb..d7d6221a1 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -281,7 +281,7 @@ class GRPOConsumer(BaseConsumer):
 
                     if self.booster.plugin.stage_manager.is_last_stage():
                         reference_action_log_probs = memory_efficient_logprob(
-                            reference_model_outputs["outputs"]["logits"],
+                            reference_model_outputs["outputs"]["logits"] / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
                             shard_config=self.plugin.shard_config,
@@ -308,7 +308,7 @@ class GRPOConsumer(BaseConsumer):
                 def _criterion(outputs, inputs):
                     action_logits = outputs.logits
                     action_log_probs = memory_efficient_logprob(
-                        action_logits,
+                        action_logits / self.generate_config["temperature"],
                         inputs["input_ids"],
                         num_action,
                         shard_config=self.plugin.shard_config,
@@ -375,7 +375,7 @@ class GRPOConsumer(BaseConsumer):
                             reference_model_logits / self.generate_config["temperature"],
                             input_ids_forward_micro_batch,
                             num_action,
-                            self.plugin.shard_config,
+                            shard_config=self.plugin.shard_config,
                         )
                         per_token_kl = (
                             torch.exp(reference_action_log_probs - action_log_probs)
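
The first two hunks make the policy and reference log-probabilities consistent with the sampling temperature already used elsewhere in the file (the third hunk's context line shows the existing pattern); the last hunk only switches a positional argument to the keyword form `shard_config=`. Below is a minimal sketch, not taken from the repository, of why the logits are divided by the temperature before the log-prob computation: rollouts are sampled from `softmax(logits / T)`, so scoring them with unscaled logits evaluates the tokens under the wrong distribution and skews the importance ratios and per-token KL in the GRPO loss. `memory_efficient_logprob` is the helper named in the diff; the plain-PyTorch calls, the temperature value, and the dummy tensor shapes here are assumptions for illustration only.

```python
import torch
import torch.nn.functional as F

temperature = 0.8  # illustrative; the diff reads it from self.generate_config["temperature"]
logits = torch.randn(2, 6, 32)            # (batch, seq_len, vocab) dummy policy logits
token_ids = torch.randint(0, 32, (2, 6))  # dummy sampled token ids

# Scale logits by the same temperature used at generation time, then take the
# log-probability of each sampled token under that (sampling) distribution.
log_probs = F.log_softmax(logits / temperature, dim=-1)
per_token_log_probs = log_probs.gather(-1, token_ids.unsqueeze(-1)).squeeze(-1)
print(per_token_log_probs.shape)  # torch.Size([2, 6])
```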