diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 5a488f5aa..1c0773f4e 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -133,7 +133,6 @@ class GRPOConsumer(BaseConsumer):
         response_length = torch.sum(action_mask, dim=1).to(torch.float32)
 
         need_update = (step_idx + 1) % self.num_microbatches == 0
-
         ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
         with ctx:
             policy_model_logits = self.policy_model(
@@ -243,13 +242,15 @@ class GRPOConsumer(BaseConsumer):
                 )
                 self.wandb_run.log(
                     {
+                        "metrics/reward": self.accum_reward.item() / self.accum_count,
+                        "metrics/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "metrics/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "metrics/response_length": self.accum_response_length.item() / self.accum_count,
                         "train/loss": self.accum_loss.item() / self.accum_count,
-                        "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
-                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
-                        "train/response_length": self.accum_response_length.item() / self.accum_count,
+                        "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
+                        "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                     }
                 )
                 self.accum_loss.zero_()
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 6cc9b3330..51a1af332 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -101,6 +101,9 @@ class BaseProducer:
                     break
                 outputs = self.rollout(**batch)
                 print(f"[P{self.producer_idx}] Send data {[(k, v.shape) for k, v in outputs.items()]}")
+                outputs["temperature"] = torch.tensor(
+                    [self.model.generate_config.temperature] * outputs["input_ids"].size(0)
+                ).to(outputs["input_ids"].device)
                 outputs = pre_send(outputs)
                 ray_broadcast_tensor_dict(
                     outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
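The producer-side change and the new `rollout/temperature` metric form a round trip: the producer attaches a per-sample temperature tensor to the rollout batch, and the consumer reads a single scalar back out for logging. The sketch below illustrates that round trip in isolation; the helper name, the toy batch shapes, and the stand-in reshape on the consumer side are assumptions for illustration, not the ColossalChat code itself.

```python
# Minimal sketch (not the ColossalChat implementation) of the temperature
# round trip added by this diff. Assumed for illustration: attach_temperature,
# the toy batch shapes, and the consumer-side view() standing in for batching.
import torch


def attach_temperature(outputs: dict, temperature: float) -> dict:
    """Mirror the producer-side change: one temperature value per sample,
    placed on the same device as input_ids."""
    batch_size = outputs["input_ids"].size(0)
    outputs["temperature"] = torch.tensor([temperature] * batch_size).to(
        outputs["input_ids"].device
    )
    return outputs


# Toy rollout batch standing in for the output of self.rollout(**batch).
outputs = {"input_ids": torch.randint(0, 100, (4, 16))}
outputs = attach_temperature(outputs, temperature=0.8)

# Consumer side: by the time the wandb call runs, "temperature" is indexable
# as a 2D tensor (hence the [0][0] in the diff); a view() fakes that here.
data = {"temperature": outputs["temperature"].view(2, 2)}
print({"rollout/temperature": data["temperature"].cpu().numpy()[0][0]})
```

Since every sample in a rollout batch shares one sampling temperature, logging the first element is enough; carrying it as a tensor keyed into `outputs` lets it ride the existing `ray_broadcast_tensor_dict` transport without a separate channel.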