Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-10-25 10:06:27 +00:00
[test] make zero engine test really work (#447)
@@ -8,6 +8,7 @@ import pytest
 import colossalai
 from colossalai.utils import free_port
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan

 import torch.multiprocessing as mp
 import torch.distributed as dist
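The newly imported has_inf_or_nan helper is what the later hunk uses to sanity-check the baseline model's gradients. As a rough, torch-only illustration of what such a check amounts to (the name has_inf_or_nan_equivalent and its implementation below are assumptions, not ColossalAI's actual code):

import torch

def has_inf_or_nan_equivalent(tensor: torch.Tensor) -> bool:
    # Assumed equivalent of the imported helper: flag any Inf or NaN entry.
    return not bool(torch.isfinite(tensor).all())

# Example: a gradient poisoned with NaN should trip the check.
assert has_inf_or_nan_equivalent(torch.tensor([1.0, float("nan")]))
assert not has_inf_or_nan_equivalent(torch.ones(3))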
@@ -32,12 +33,13 @@ def run_dist(rank, world_size, port, parallel_config):
     colo_model = model_builder(checkpoint=True)
     torch_model = copy.deepcopy(colo_model).cuda()
     torch_model.train()
     engine, train_dataloader, _, _ = colossalai.initialize(colo_model,
                                                             optimizer=optimizer_class,
                                                             criterion=criterion,
                                                             train_dataloader=train_dataloader)
     engine.train()
-    torch_optimizer = optimizer_class(torch_model.parameters())
+    torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3)

     if dist.get_world_size() > 1:
         torch_model = DDP(torch_model)
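For context, a run_dist function with this signature is typically driven once per rank by a small spawn harness. A minimal sketch of such a harness, assuming the usual colossalai.launch keyword arguments; the outer function name and the empty config are illustrative, not taken from this file:

from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.utils import free_port


def run_dist(rank, world_size, port, parallel_config):
    # Each spawned process joins the same process group, then runs the test body above.
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    # ... build models, call colossalai.initialize(...), run the training loop ...


def test_zero_engine(world_size=2):  # illustrative name
    run_func = partial(run_dist,
                       world_size=world_size,
                       port=free_port(),
                       parallel_config=dict())  # placeholder; e.g. the ZERO_PARALLEL_CONFIG used below
    mp.spawn(run_func, nprocs=world_size)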
@@ -66,15 +68,17 @@ def run_dist(rank, world_size, port, parallel_config):
         engine.step()

         torch_loss.backward()

         for param in torch_model.parameters():
             if param.grad is not None:
                 assert not has_inf_or_nan(param.grad)

         torch_optimizer.step()
         i += 1

     # for torch_param, zero_param in zip(torch_model.parameters(), colo_model.parameters()):
     #     assert torch.allclose(torch_param, zero_param), f"diff {torch_param - zero_param}"

     if parallel_config == MP_PARALLEL_CONFIG:
         check_params(torch_model, colo_model, loose=True)
-    elif isinstance(colo_model, ShardedModelV2):
+    elif parallel_config == ZERO_PARALLEL_CONFIG:
         check_sharded_params_padding(torch_model, colo_model, loose=True)
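check_params and check_sharded_params_padding come from the test utilities, and the latter also accounts for ZeRO's parameter sharding and padding. Ignoring that detail, the comparison they perform boils down to something like the torch-only sketch below, in the spirit of the commented-out allclose check above; the helper name and tolerance values are assumptions:

import torch

def params_close(torch_model, colo_model, loose=False):
    # Pairwise parameter comparison; 'loose' widens the tolerance, mirroring loose=True above.
    rtol, atol = (1e-3, 1e-3) if loose else (1e-5, 1e-8)
    for torch_param, colo_param in zip(torch_model.parameters(), colo_model.parameters()):
        assert torch.allclose(torch_param.float(), colo_param.float(), rtol=rtol, atol=atol), \
            f"max diff {(torch_param - colo_param).abs().max().item()}"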