[chat] fix bugs in stage 3 training (#3759)

Co-authored-by: Yuanchen Xu <yuanchen.xu00@gmail.com>
This commit is contained in:
Yuanchen
2023-05-17 17:44:05 +08:00
committed by GitHub
parent 5dd573c6b6
commit 05759839bd
4 changed files with 32 additions and 14 deletions

View File

@@ -45,7 +45,7 @@ class PromptDataset(Dataset):
self.keyed_prompt[k].extend(tensor.to(torch.cuda.current_device()).unbind())
def __len__(self):
return len(self.keyed_prompt)
return len(self.keyed_prompt["input_ids"])
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return {k: v[i] for k, v in self.keyed_prompt.items()}