diff --git a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py index b43ba65c0..ba83f7787 100644 --- a/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py +++ b/applications/ColossalChat/coati/distributed/reward/verifiable_reward.py @@ -21,7 +21,7 @@ class VerifiableReward: # Get batch size bs = input_ids.size(0) # Initialize reward - rewards = torch.zeros(bs, device=input_ids.device) + rewards = torch.zeros((bs, 3), device=input_ids.device) # Loop through reward functions for reward_fn in self.reward_fns: