del excess code

This commit is contained in:
ElcarimQAQ 2025-04-18 14:19:28 +08:00
parent ad56d16c1d
commit dd28ab45df

View File

@ -240,14 +240,10 @@ class SimpleProducer(BaseProducer):
all_accs = torch.tensor([], device=self.device) all_accs = torch.tensor([], device=self.device)
for test_batch in self.val_dataloader: for test_batch in self.val_dataloader:
# test_batch['input_ids'].size() [32, 300]
# test_batch["gt_answer"] orch.Size([32, 1, 300])
test_output = self.rollout(**test_batch) test_output = self.rollout(**test_batch)
# test_output["response_idx"] torch.Size([32, 8, 2])
num_generations = test_output["response_idx"].size(1) num_generations = test_output["response_idx"].size(1)
print("num_generations", num_generations) print("num_generations", num_generations)
data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()} data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()}
# data = test_output
reward_group = self.reward_model( reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]) data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])