mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 17:46:42 +00:00
fixed zero level 3 dtype bug (#76)
This commit is contained in:
@@ -89,10 +89,10 @@ def run_dist(rank, world_size):
|
||||
model.train()
|
||||
for idx, (data, label) in enumerate(train_dataloader):
|
||||
engine.zero_grad()
|
||||
data = data.cuda().half()
|
||||
data = data.cuda()
|
||||
label = label.cuda()
|
||||
|
||||
output = engine(data).float()
|
||||
output = engine(data)
|
||||
loss = engine.criterion(output, label)
|
||||
|
||||
engine.backward(loss)
|
||||
@@ -104,7 +104,6 @@ def run_dist(rank, world_size):
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
|
||||
def test_zero_level_3():
|
||||
world_size = 4
|
||||
run_func = partial(run_dist, world_size=world_size)
|
||||
|
@@ -108,7 +108,6 @@ def run_2d_parallel_vision_transformer_level_3(rank, world_size):
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.skip("Level 3 has unknown bug so skip this test for now")
|
||||
def test_3d_vit_zero_level_3():
|
||||
world_size = 8
|
||||
run_func = partial(run_2d_parallel_vision_transformer_level_3, world_size=world_size)
|
||||
|
Reference in New Issue
Block a user