precision alignment

This commit is contained in:
yuehuayingxueluo
2024-01-02 18:30:11 +08:00
committed by FrankLeeeee
parent 62968588d1
commit 9489dc64d8
5 changed files with 45 additions and 47 deletions

View File

@@ -308,7 +308,7 @@ class BatchInfo:
input_len_list.append(1)
return torch.tensor(input_list, dtype=torch.long, device=self.device), torch.tensor(
input_len_list, dtype=torch.int, device=device
input_len_list, dtype=torch.int, device=self.device
)
def get_sequence_lengths(self):