diff --git a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py index b3cfa60bd..d398c4365 100644 --- a/tests/test_layers/test_3d/checks_3d/check_layer_3d.py +++ b/tests/test_layers/test_3d/checks_3d/check_layer_3d.py @@ -784,7 +784,7 @@ def check_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i] @@ -837,7 +837,7 @@ def check_vocab_parallel_loss(): out_shape = (BATCH_SIZE, NUM_CLASSES) out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device) torch.distributed.broadcast(out_master, src=0) torch.distributed.broadcast(target_master, src=0) out = torch.chunk(out_master, DEPTH, dim=0)[i]