[booster] added the plugin base and torch ddp plugin (#3180)

* [booster] added the plugin base and torch ddp plugin

* polish code

* polish code

* polish code
Author: Frank Lee
Date: 2023-03-21 17:39:30 +08:00
Committed by: GitHub
Parent commit: e5f668f280
Commit: e7f3bed2d3
8 changed files with 378 additions and 86 deletions
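The new tests in this change exercise the booster plugin API end to end: build a TorchDDPPlugin, hand it to a Booster, prepare a sharded dataloader, boost the model/optimizer/criterion, and drive the training step through the returned wrappers. Below is a rough, non-authoritative sketch assembled from those tests; the toy model, dataset, and hard-coded port are placeholders of mine, while the calls themselves are the ones the tests use.

import torch
import torch.nn as nn
from torch.optim import SGD
from torch.utils.data import TensorDataset

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

# Rough sketch only: assumes a single GPU and a free rendezvous port; the toy
# model and dataset are placeholders, not part of this PR.
colossalai.launch(config=dict(), rank=0, world_size=1, host='localhost', port=29500)

plugin = TorchDDPPlugin()
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8).cuda()
optimizer = SGD(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()
dataset = TensorDataset(torch.rand(16, 8))
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)

# boost() returns the DDP-wrapped model and an OptimizerWrapper, as the new
# test_torch_ddp_plugin test below asserts
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

for (batch,) in train_dataloader:
    loss = criterion(model(batch.cuda()))
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()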


@@ -1,13 +1,27 @@
import pytest
from functools import partial

import torch.multiprocessing as mp
import torch.nn as nn
from torchvision.models import resnet18

from colossalai.booster.accelerator import Accelerator
from colossalai.testing import parameterize, rerun_if_address_is_in_use


@pytest.mark.parametrize('device', ['cpu', 'cuda'])
def test_accelerator(device):
@parameterize('device', ['cpu', 'cuda'])
def run_accelerator(device):
    accelerator = Accelerator(device)
    model = nn.Linear(8, 8)
    model = accelerator.configure_model(model)
    assert next(model.parameters()).device.type == device
    del model, accelerator


def run_dist(rank):
    run_accelerator()


@rerun_if_address_is_in_use()
def test_accelerator():
    world_size = 1
    run_func = partial(run_dist)
    mp.spawn(run_func, nprocs=world_size)
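The rewritten accelerator test adopts the spawn-based distributed test harness used across this change. A minimal sketch of that harness on its own, assuming nothing beyond the helpers the tests already import (worker and test_something are placeholder names):

from functools import partial

import torch.multiprocessing as mp

import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port


def worker(rank, world_size, port):
    # each spawned process joins the same distributed group
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    # ... run the actual checks here ...


@rerun_if_address_is_in_use()    # rerun the test if the rendezvous address is already taken
def test_something():
    world_size = 1
    run_func = partial(worker, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)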


@@ -1,12 +1,21 @@
from functools import partial

import torch
import torch.multiprocessing as mp
from torch.optim import Adam

import colossalai
from colossalai.booster.mixed_precision import FP16TorchMixedPrecision
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo


def test_torch_amp():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
def run_torch_amp(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    sub_model_zoo = model_zoo.get_sub_registry('timm')
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
        # dlrm_interactionarch has no parameters, so skip it
        if name == 'dlrm_interactionarch':
            continue

@@ -27,3 +36,11 @@ def test_torch_amp():
        optimizer.backward(loss)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()
        del model, optimizer, criterion, data, output, mixed_precision


@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
    world_size = 1
    run_func = partial(run_torch_amp, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
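Only the tail of the mixed-precision training step is visible in the hunk above, so the following is a speculative sketch of how FP16TorchMixedPrecision might be exercised: the construction and configure()-style wrapping call are assumptions not shown in this diff, while optimizer.backward, clip_grad_by_norm, and step are taken directly from it.

import torch
import torch.nn as nn
from torch.optim import Adam

from colossalai.booster.mixed_precision import FP16TorchMixedPrecision

model = nn.Linear(8, 8).cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = lambda x: x.mean()

mixed_precision = FP16TorchMixedPrecision()
# ASSUMPTION: a configure()-style entry point that returns wrapped objects;
# the wrapping step is not shown in the hunk above.
model, optimizer, criterion = mixed_precision.configure(model, optimizer, criterion)

output = model(torch.rand(4, 8).cuda())
loss = criterion(output)

# these three calls appear verbatim in the visible part of the test
optimizer.backward(loss)
optimizer.clip_grad_by_norm(1.0)
optimizer.step()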


@@ -0,0 +1,85 @@
from functools import partial

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster
from colossalai.booster.interface import OptimizerWrapper
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from tests.kit.model_zoo import model_zoo


def check_torch_ddp_plugin():
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        # dlrm_interactionarch has no parameters, so skip it
        if name == 'dlrm_interactionarch':
            continue
        model = model_fn()
        optimizer = SGD(model.parameters(), lr=1e-3)
        criterion = lambda x: x.mean()
        data = data_gen_fn()
        data = {
            k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()
        }

        model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

        assert isinstance(model, DDP)
        assert isinstance(optimizer, OptimizerWrapper)

        output = model(**data)
        output = output_transform_fn(output)
        output_key = list(output.keys())[0]
        loss = criterion(output[output_key])

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()


def check_dataloader_sharding():
    plugin = TorchDDPPlugin()

    # create a custom dataset with the values 0-9
    dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
    train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)

    # get the first batch of data
    batch = next(iter(train_dataloader))[0].cuda()
    is_rank_0 = dist.get_rank() == 0

    if is_rank_0:
        batch_to_compare = batch.clone()
    else:
        batch_to_compare = batch
    # send rank 1's batch to rank 0 for comparison
    dist.broadcast(batch_to_compare, src=1)

    # compare on rank 0
    if is_rank_0:
        assert not torch.equal(batch,
                               batch_to_compare), 'Same number was found across ranks but expected it to be different'


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_dataloader_sharding()
    check_torch_ddp_plugin()


@rerun_if_address_is_in_use()
def test_torch_ddp_plugin():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)
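The sharding check above hinges on prepare_train_dataloader giving each rank a different slice of the dataset. A small sketch of the same behaviour outside the test harness, assuming a two-process group has already been launched as in run_dist (the print is illustrative only):

import torch
import torch.distributed as dist

from colossalai.booster.plugin import TorchDDPPlugin

# assumes colossalai.launch(...) has already initialised a 2-process group,
# exactly as run_dist does above
plugin = TorchDDPPlugin()
dataset = torch.utils.data.TensorDataset(torch.arange(0, 10))
train_dataloader = plugin.prepare_train_dataloader(dataset, batch_size=2)

first_batch = next(iter(train_dataloader))[0]
# each rank receives a different slice, so the first batch differs across
# ranks; this is exactly what check_dataloader_sharding asserts
print(f'rank {dist.get_rank()} first batch: {first_batch.tolist()}')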