mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-21 09:29:47 +00:00
[chat] fix bugs in stage 3 training (#3759)
Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
@@ -45,7 +45,7 @@ class PromptDataset(Dataset):
|
||||
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.keyed_prompt)
|
||||
return len(self.keyed_prompt["input_ids"])
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return {k: v[i] for k, v in self.keyed_prompt.items()}
|
||||
|
Reference in New Issue
Block a user