diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 8408693b5..0535fbd78 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -240,14 +240,10 @@ class SimpleProducer(BaseProducer): all_accs = torch.tensor([], device=self.device) for test_batch in self.val_dataloader: - # test_batch['input_ids'].size() [32, 300] - # test_batch["gt_answer"] orch.Size([32, 1, 300]) test_output = self.rollout(**test_batch) - # test_output["response_idx"] torch.Size([32, 8, 2]) num_generations = test_output["response_idx"].size(1) print("num_generations", num_generations) data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()} - # data = test_output reward_group = self.reward_model( data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])