mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-22 09:49:14 +00:00
del excess code
This commit is contained in:
parent
ad56d16c1d
commit
dd28ab45df
@ -240,14 +240,10 @@ class SimpleProducer(BaseProducer):
|
||||
all_accs = torch.tensor([], device=self.device)
|
||||
|
||||
for test_batch in self.val_dataloader:
|
||||
# test_batch['input_ids'].size() [32, 300]
|
||||
# test_batch["gt_answer"] orch.Size([32, 1, 300])
|
||||
test_output = self.rollout(**test_batch)
|
||||
# test_output["response_idx"] torch.Size([32, 8, 2])
|
||||
num_generations = test_output["response_idx"].size(1)
|
||||
print("num_generations", num_generations)
|
||||
data = {k: v.view(-1, v.size(-1)) for k, v in test_output.items()}
|
||||
# data = test_output
|
||||
reward_group = self.reward_model(
|
||||
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user