Fixed ZeRO level 3 dtype bug (#76)

This commit is contained in:
Frank Lee
2021-12-20 17:00:53 +08:00
committed by GitHub
parent 632e622de8
commit 91c327cb44
5 changed files with 16 additions and 12 deletions

View File

@@ -89,10 +89,10 @@ def run_dist(rank, world_size):
model.train()
for idx, (data, label) in enumerate(train_dataloader):
engine.zero_grad()
-data = data.cuda().half()
+data = data.cuda()
label = label.cuda()
-output = engine(data).float()
+output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
@@ -104,7 +104,6 @@ def run_dist(rank, world_size):
@pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
def test_zero_level_3():
world_size = 4
run_func = partial(run_dist, world_size=world_size)

View File

@@ -108,7 +108,6 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
@pytest.mark.dist
-@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
def test_3d_vit_zero_level_3():
world_size = 8
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)