added buffer sync to naive amp model wrapper (#291)

This commit is contained in:
Frank Lee
2022-03-02 16:47:17 +08:00
parent 8d653af408
commit e17e54e32a
4 changed files with 191 additions and 46 deletions

View File

@@ -0,0 +1,84 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from torchvision.models import resnet50
import torch.distributed as dist
def run_dist(rank, world_size, port):
# need to configure cudnn deterministic so that
# randomness of convolution layers will be disabled
colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
cudnn_determinstic=True,
cudnn_benchmark=False),
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
model = resnet50()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
engine, *args = colossalai.initialize(model, optimizer, criterion)
# train for dummy iterations
engine.train()
for _ in range(2):
data = torch.rand(4, 3, 128, 128).cuda().half()
label = torch.randint(0, 10, size=(4,)).cuda()
engine.zero_grad()
out = engine(data)
loss = engine.criterion(out, label)
engine.backward(loss)
engine.step()
# test
# need to make sure the batch norm stats are synchronized
# so that given the same input, the model will produce the same
# output on different ranks
engine.eval()
data = torch.rand(4, 3, 128, 128).cuda().half()
dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
# predict
out = engine(data)
# test if results are equal
tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
tensor_list.insert(rank, out)
dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
assert torch.all(tensor_list[0] == tensor_list[1]), \
'expected the output from different ranks to be the same, but got different values'
@pytest.mark.dist
def test_sharded_optim_with_sync_bn():
"""
This test is to make sure that buffers are synchronized between ranks
when using ZeRO. An example of module buffer is the running stats of
BatchNormalization layer, i.e. mean and var.
If the buffers are not synchronized, the model will produce different
output even though the input and parameters are the same. This is not
wanted if we are doing predictions.
"""
world_size = 2
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_sharded_optim_with_sync_bn()