diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 68a01528e..e23254d1b 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -285,7 +285,6 @@ class GRPOConsumer(BaseConsumer): kl = all_reduce_mean(torch.mean(torch.stack(kl)).to(loss.device), self.plugin).data mean_kl.append(kl) mean_loss.append(all_reduce_mean(loss, self.plugin).data) - torch.cuda.empty_cache() else: policy_model_logits = self.policy_model(