[test] make zero engine test really work (#447)

Author: Jiarui Fang
Date: 2022-03-17 17:24:25 +08:00
Committed by: GitHub
Parent: bb2790cf0b
Commit: 0fcfb1e00d
7 changed files with 39 additions and 28 deletions

@@ -8,6 +8,7 @@ import pytest
 import colossalai
 from colossalai.utils import free_port
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 import torch.multiprocessing as mp
 import torch.distributed as dist
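The newly imported `has_inf_or_nan` is used further down to sanity-check the reference model's gradients. As a rough illustration of what such a helper checks (this is a minimal stand-in, not the implementation in `colossalai.zero.sharded_optim._utils`):

```python
import torch

def tensor_has_inf_or_nan(t: torch.Tensor) -> bool:
    """Return True if the tensor contains any inf or NaN values.

    Minimal sketch for illustration only; the real helper lives in
    colossalai.zero.sharded_optim._utils and may be implemented differently.
    """
    return bool(torch.isinf(t).any() or torch.isnan(t).any())

# Example: a gradient poisoned with NaN should be flagged.
grad = torch.randn(4)
grad[2] = float("nan")
assert tensor_has_inf_or_nan(grad)
```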
@@ -32,12 +33,13 @@ def run_dist(rank, world_size, port, parallel_config):
     colo_model = model_builder(checkpoint=True)
     torch_model = copy.deepcopy(colo_model).cuda()
     torch_model.train()
     engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                             optimizer=optimizer_class,
                                                             criterion=criterion,
                                                             train_dataloader=train_dataloader)
     engine.train()
-    torch_optimizer = optimizer_class(torch_model.parameters())
+    torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)
     if dist.get_world_size() > 1:
         torch_model = DDP(torch_model)
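The hunk above builds the ZeRO engine and, alongside it, a plain PyTorch baseline cloned from the same model; passing `lr=1e-3` explicitly keeps the baseline optimizer aligned with the engine's configuration rather than relying on the optimizer's default. A hypothetical helper distilling that setup pattern (the test itself does this inline; the function name and default lr here are illustrative assumptions):

```python
import copy
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def build_reference_model(colo_model, optimizer_class, lr=1e-3):
    """Clone the model under test into a plain PyTorch baseline.

    Illustrative sketch: deep-copy the model, move it to GPU, and create
    the baseline optimizer with the same lr as the engine side so the two
    training runs stay comparable.
    """
    torch_model = copy.deepcopy(colo_model).cuda()
    torch_model.train()
    torch_optimizer = optimizer_class(torch_model.parameters(), lr=lr)
    if dist.is_initialized() and dist.get_world_size() > 1:
        # Wrap the baseline in DDP so it sees the same gradient averaging.
        torch_model = DDP(torch_model)
    return torch_model, torch_optimizer
```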
@@ -66,15 +68,17 @@ def run_dist(rank, world_size, port, parallel_config):
         engine.step()
         torch_loss.backward()
+        for param in torch_model.parameters():
+            if param.grad is not None:
+                assert not has_inf_or_nan(param.grad)
         torch_optimizer.step()
         i += 1
     # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
     #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"
     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
-    elif isinstance(colo_model, ShardedModelV2):
+    elif parallel_config == ZERO_PARALLEL_CONFIG:
         check_sharded_params_padding(torch_model, colo_model, loose=True)
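After training, the test compares the baseline's parameters against the ZeRO model's with loose tolerances. As a naive sketch of that idea (the real `check_params` / `check_sharded_params_padding` helpers additionally handle ZeRO's parameter sharding and padding, and the tolerance values below are assumptions, not the test's settings):

```python
import torch

def compare_params_loosely(torch_model, colo_model, rtol=1e-3, atol=1e-1):
    """Assert that corresponding parameters match within relaxed tolerances.

    Illustrative only: does not account for sharded or padded parameters.
    """
    for torch_p, colo_p in zip(torch_model.parameters(), colo_model.parameters()):
        colo_p = colo_p.float().to(torch_p.device)
        assert torch.allclose(torch_p.float(), colo_p, rtol=rtol, atol=atol), \
            f"max diff {(torch_p.float() - colo_p).abs().max()}"
```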