fix dataloader

This commit is contained in:
YeAnbang
2024-06-24 05:10:44 +00:00
parent 4b59d874df
commit 0b2d6275c4
5 changed files with 85 additions and 63 deletions

View File

@@ -109,8 +109,6 @@ def calc_masked_log_probs(
if not length_normalization:
return log_probs * mask
else:
if torch.any(mask.sum(dim=-1) == 0):
print("Mask should not be all zeros.")
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)