mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
added buffer sync to naive amp model wrapper (#291)
This commit is contained in:
@@ -0,0 +1,84 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from torchvision.models import resnet50
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
# need to configure cudnn deterministic so that
|
||||
# randomness of convolution layers will be disabled
|
||||
colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
|
||||
cudnn_determinstic=True,
|
||||
cudnn_benchmark=False),
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
|
||||
model = resnet50()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
engine, *args = colossalai.initialize(model, optimizer, criterion)
|
||||
|
||||
# train for dummy iterations
|
||||
engine.train()
|
||||
for _ in range(2):
|
||||
data = torch.rand(4, 3, 128, 128).cuda().half()
|
||||
label = torch.randint(0, 10, size=(4,)).cuda()
|
||||
engine.zero_grad()
|
||||
out = engine(data)
|
||||
loss = engine.criterion(out, label)
|
||||
engine.backward(loss)
|
||||
engine.step()
|
||||
|
||||
# test
|
||||
# need to make sure the batch norm stats are synchronized
|
||||
# so that given the same input, the model will produce the same
|
||||
# output on different ranks
|
||||
engine.eval()
|
||||
data = torch.rand(4, 3, 128, 128).cuda().half()
|
||||
dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
# predict
|
||||
out = engine(data)
|
||||
|
||||
# test if results are equal
|
||||
tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)]
|
||||
tensor_list.insert(rank, out)
|
||||
dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA))
|
||||
|
||||
assert torch.all(tensor_list[0] == tensor_list[1]), \
|
||||
'expected the output from different ranks to be the same, but got different values'
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_sharded_optim_with_sync_bn():
|
||||
"""
|
||||
This test is to make sure that buffers are synchronized between ranks
|
||||
when using ZeRO. An example of module buffer is the running stats of
|
||||
BatchNormalization layer, i.e. mean and var.
|
||||
|
||||
If the buffers are not synchronized, the model will produce different
|
||||
output even though the input and parameters are the same. This is not
|
||||
wanted if we are doing predictions.
|
||||
|
||||
"""
|
||||
world_size = 2
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_sharded_optim_with_sync_bn()
|
Reference in New Issue
Block a user