mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 18:19:58 +00:00
fix dataloader
This commit is contained in:
@@ -109,8 +109,6 @@ def calc_masked_log_probs(
|
||||
if not length_normalization:
|
||||
return log_probs * mask
|
||||
else:
|
||||
if torch.any(mask.sum(dim=-1) == 0):
|
||||
print("Mask should not be all zeros.")
|
||||
return log_probs * mask / (mask.sum(dim=-1, keepdim=True) + 0.01)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user