diff --git a/tests/components_to_test/nested_model.py b/tests/components_to_test/nested_model.py
index 5f32b08e9..edf4a1a89 100644
--- a/tests/components_to_test/nested_model.py
+++ b/tests/components_to_test/nested_model.py
@@ -43,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator):
 @non_distributed_component_funcs.register(name='nested_model')
 def get_training_components():
 
-    def model_builder(checkpoint):
+    def model_builder(checkpoint=True):
         return NestedNet(checkpoint)
 
     trainloader = DummyDataLoader()
diff --git a/tests/components_to_test/utils/dummy_data_generator.py b/tests/components_to_test/utils/dummy_data_generator.py
index aabcd30e4..5ab33e86d 100644
--- a/tests/components_to_test/utils/dummy_data_generator.py
+++ b/tests/components_to_test/utils/dummy_data_generator.py
@@ -3,12 +3,23 @@ from abc import ABC, abstractmethod
 
 class DummyDataGenerator(ABC):
 
+    def __init__(self, length=10):
+        self.length = length
+
     @abstractmethod
     def generate(self):
         pass
 
     def __iter__(self):
+        self.step = 0
         return self
 
     def __next__(self):
-        return self.generate()
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
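
The `DummyDataGenerator` change above turns a previously unbounded iterator into a sized iterable: `__iter__` resets a step counter, `__next__` raises `StopIteration` after `length` batches, and `__len__` exposes the epoch length. A minimal, self-contained sketch of how a concrete generator uses this protocol (the subclass and its tensor shapes are illustrative, not taken from the repository):

    from abc import ABC, abstractmethod

    import torch


    class DummyDataGenerator(ABC):
        # mirrors the patched base class above

        def __init__(self, length=10):
            self.length = length

        @abstractmethod
        def generate(self):
            pass

        def __iter__(self):
            self.step = 0
            return self

        def __next__(self):
            if self.step < self.length:
                self.step += 1
                return self.generate()
            raise StopIteration

        def __len__(self):
            return self.length


    class RandomBatchLoader(DummyDataGenerator):
        # hypothetical subclass for illustration only

        def generate(self):
            data = torch.rand(4, 3, 32, 32)
            label = torch.randint(low=0, high=10, size=(4,))
            return data, label


    loader = RandomBatchLoader(length=4)
    assert len(loader) == 4
    assert sum(1 for _ in loader) == 4    # iteration now terminates after `length` batches

Without the bound, any consumer that loops to exhaustion (such as `trainer.fit` below) would never finish an epoch; `max_steps` was the only safeguard.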
diff --git a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
index 599efd883..9ae21cf77 100644
--- a/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
+++ b/tests/test_trainer/test_trainer_with_non_pipe_schedule.py
@@ -1,21 +1,14 @@
-import os
 from functools import partial
-from pathlib import Path
 
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.amp.amp_type import AMP_TYPE
-from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, free_port, get_dataloader
-from torch.optim import Adam
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
-from torchvision.models import resnet18
+from colossalai.utils import MultiTimer, free_port
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 BATCH_SIZE = 16
 IMG_SIZE = 32
@@ -29,60 +22,32 @@ CONFIG = dict(
 
 def run_trainer_no_pipeline(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    # build model
-    model = resnet18(num_classes=10)
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model = model_builder()
+        optimizer = optimizer_builder(model)
+        engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                             optimizer=optimizer,
+                                                             criterion=criterion,
+                                                             train_dataloader=train_dataloader)
 
-    # build dataloaders
-    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                            download=True,
-                            transform=transforms.Compose([
-                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                                transforms.ToTensor(),
-                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                            ]))
+        logger = get_dist_logger()
+        logger.info("engine is built", ranks=[0])
 
-    test_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                           train=False,
-                           download=True,
-                           transform=transforms.Compose([
-                               transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                               transforms.ToTensor(),
-                               transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                           ]))
+        timer = MultiTimer()
+        trainer = Trainer(engine=engine, logger=logger, timer=timer)
+        logger.info("trainer is built", ranks=[0])
 
-    train_dataloader = get_dataloader(dataset=train_dataset,
-                                      shuffle=True,
-                                      batch_size=BATCH_SIZE,
-                                      pin_memory=True,
-                                      drop_last=True)
-
-    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True)
-
-    # build optimizer
-    optimizer = Adam(model.parameters(), lr=0.001)
-    criterion = nn.CrossEntropyLoss()
-
-    engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                            optimizer=optimizer,
-                                                            criterion=criterion,
-                                                            train_dataloader=train_dataloader)
-
-    logger = get_dist_logger()
-    logger.info("engine is built", ranks=[0])
-
-    timer = MultiTimer()
-    trainer = Trainer(engine=engine, logger=logger, timer=timer)
-    logger.info("trainer is built", ranks=[0])
-
-    logger.info("start training", ranks=[0])
-    trainer.fit(train_dataloader=train_dataloader,
-                test_dataloader=test_dataloader,
-                epochs=NUM_EPOCHS,
-                max_steps=100,
-                display_progress=True,
-                test_interval=5)
-    gpc.destroy()
-    torch.cuda.empty_cache()
+        logger.info("start training", ranks=[0])
+        trainer.fit(train_dataloader=train_dataloader,
+                    test_dataloader=test_dataloader,
+                    epochs=NUM_EPOCHS,
+                    max_steps=5,
+                    display_progress=True,
+                    test_interval=5)
+        torch.cuda.empty_cache()
 
 
 @pytest.mark.dist
diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py
index c20dcd89c..eb1b267f6 100644
--- a/tests/test_utils/test_zero_gradient_clippling.py
+++ b/tests/test_utils/test_zero_gradient_clippling.py
@@ -20,6 +20,7 @@ from torch.nn.utils import clip_grad_norm_
 
 
 class Enumerator:
+
     def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
         self.arg_names = arg_names
         self.enums = Enumerator.all_enumerate(arg_values)
@@ -49,11 +50,12 @@ class Enumerator:
 
 def checkpoint_wrapper(module, enable=True):
     if enable:
-        module.forward = partial(checkpoint, module.forward)
+        module.forward = partial(checkpoint, module.forward, False)
     return module
 
 
 class Net(nn.Module):
+
     def __init__(self, checkpoint=False) -> None:
         super().__init__()
         self.fc1 = nn.Linear(5, 5)
@@ -61,13 +63,7 @@ class Net(nn.Module):
         self.fc3 = nn.Linear(5, 1)
         if checkpoint:
             self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
+        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
 
     def forward(self, x):
         for layer in self.layers:
@@ -158,12 +154,7 @@ def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):
 
 def run_dist(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     args = ['checkpoint', 'fp16', 'offload', 'norm_type']
     arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
@@ -176,7 +167,7 @@ def run_dist(rank, world_size, port):
     check_config()
 
 
-@ pytest.mark.dist
+@pytest.mark.dist
def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
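
The trainer test now pulls its models from the shared component registry instead of building a CIFAR10/resnet18 pipeline inline; each registered component function returns the same 5-tuple of `(model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion)`. A minimal sketch of the registry mechanism these tests rely on — an illustrative stand-in, not the actual implementation in `tests/components_to_test/registry`:

    # Illustrative stand-in for the registry the tests import.
    class ComponentRegistry:

        def __init__(self):
            self._registry = {}

        def register(self, name):
            # used as a decorator: @registry.register(name='nested_model')
            def wrapper(func):
                self._registry[name] = func
                return func
            return wrapper

        def get_callable(self, name):
            return self._registry[name]

        def __iter__(self):
            # supports `for get_components_func in registry:` as used in
            # test_shard_param.py below
            return iter(self._registry.values())


    non_distributed_component_funcs = ComponentRegistry()


    @non_distributed_component_funcs.register(name='dummy')
    def get_training_components():
        # every component function returns the same 5-tuple shape;
        # the bodies here are placeholders
        model_builder = lambda checkpoint=False: ...
        train_dataloader, test_dataloader = ..., ...
        optimizer_builder = lambda model: ...
        criterion = ...
        return model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion

Registering builders rather than instances means each test constructs a fresh model on its own device and dtype, which is why the signatures changed from `model` to `model_builder` throughout this patch.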
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index 163b098c0..f9fe86bf6 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -43,23 +43,6 @@ def checkpoint_wrapper(module, enable=True):
     return module
 
 
-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index ce564be46..5c70e5274 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -13,7 +13,8 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, Net, allclose
+from tests.test_zero_data_parallel.common import CONFIG, allclose
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def _run_shard_tensor(rank, world_size, port):
@@ -68,21 +69,22 @@ def _run_test_shard_param(rank, world_size, port):
     print(param_ref.data)
 
     logger = get_dist_logger()
-    model = Net()
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, *_ = get_components_func()
+        model = model_builder(checkpoint=True)
+        # add an attribute as col_attr to hijack the access to param.data
+        for _, param in model.named_parameters():
+            numel_ref = (param.numel() + world_size - 1) // world_size
+            param.col_attr = ShardedParam(param)
+            param.col_attr.shard()
+            param_data = param.col_attr.payload(torch.device('cpu'))
+            assert (numel_ref == param_data.numel())
 
-    # add an attribute as ca_attr to hijack the access to param.data
-    for _, param in model.named_parameters():
-        numel_ref = (param.numel() + world_size - 1) // world_size
-        param.ca_attr = ShardedParam(param)
-        param.ca_attr.shard()
-        param_data = param.ca_attr.payload(torch.device('cpu'))
-        assert (numel_ref == param_data.numel())
+        for _, param in model.named_parameters():
+            param.col_attr.gather()
+            param_data = param.col_attr.payload(torch.device('cpu'))
 
-    for _, param in model.named_parameters():
-        param.ca_attr.gather()
-        param_data = param.ca_attr.payload(torch.device('cpu'))
-
-    disable_existing_loggers([logger])
+        disable_existing_loggers([logger])
 
 
 @pytest.mark.dist
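
The assertion in `_run_test_shard_param` relies on ceiling division: after sharding, each rank holds `ceil(numel / world_size)` elements, which implies the last shard is padded whenever the parameter size does not divide evenly. The arithmetic, runnable standalone:

    def expected_shard_numel(numel: int, world_size: int) -> int:
        # ceiling division without floats, exactly as written in the test above
        return (numel + world_size - 1) // world_size

    assert expected_shard_numel(10, 4) == 3    # 10 elements over 4 ranks, padded to 12
    assert expected_shard_numel(12, 4) == 3    # divides evenly, no padding
    assert expected_shard_numel(13, 4) == 4    # padded to 16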
diff --git a/tests/test_zero_data_parallel/test_sharded_optim.py b/tests/test_zero_data_parallel/test_sharded_optim.py
index def748f31..6720b8e39 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim.py
@@ -3,19 +3,13 @@
 import colossalai
 import copy
 import pytest
 import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.zero import ShardedOptimizer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.utils import free_port
 from functools import partial
-
-
-def check_equal(a, b):
-    """
-    This function checks if two tensors are equal within tolerance
-    """
-    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+from common import allclose
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def check_completely_equal(a, b):
@@ -36,61 +30,56 @@ def check_sharded_param_consistency():
     pg: partition gradients and optimizer states
     """
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
 
-    # create layers
-    oss_linear1 = nn.Linear(128, 256)
-    oss_linear2 = nn.Linear(256, 512)
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()
 
-    # create model
-    oss_model = nn.Sequential(oss_linear1, oss_linear2)
-    pg_model = copy.deepcopy(oss_model)
+        # create model
+        oss_model = model_builder(checkpoint=True).cuda().half()
+        pg_model = copy.deepcopy(oss_model)
 
-    oss_model = oss_model.cuda().half()
-    pg_model = pg_model.cuda().half()
+        # create optimizer
+        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
+        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
+        oss_optimizer = ShardedOptimizer(oss_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
+        pg_optimizer = ShardedOptimizer(pg_optimizer,
+                                        overlap_communication=True,
+                                        partition_grad=True,
+                                        initial_scale=1,
+                                        clip_grad_norm=0.0)
 
-    # create optimizer
-    oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
-    pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
-    oss_optimizer = ShardedOptimizer(oss_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
-    pg_optimizer = ShardedOptimizer(pg_optimizer,
-                                    overlap_communication=True,
-                                    partition_grad=True,
-                                    initial_scale=1,
-                                    clip_grad_norm=0.0)
+        # create
+        data, label = next(iter(train_dataloader))
+        input_data = data.cuda().half()
 
-    # create
-    input_data = torch.rand(32, 128).cuda().half()
+        # forward
+        oss_output = oss_model(input_data)
+        pg_output = pg_model(input_data)
+        check_completely_equal(oss_output, pg_output)
 
-    # forward
-    oss_output = oss_model(input_data)
-    pg_output = pg_model(input_data)
-    check_completely_equal(oss_output, pg_output)
+        # backward
+        oss_optimizer.backward(oss_output.mean().float())
+        pg_optimizer.backward(pg_output.mean().float())
 
-    # backward
-    oss_optimizer.backward(oss_output.mean().float())
-    pg_optimizer.backward(pg_output.mean().float())
+        # check grad
+        # as this param is small, the backward reduction
+        # will not be fired
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param.grad, pg_param.grad)
 
-    # check grad
-    # as this param is small, the backward reduction
-    # will not be fired
-    oss_linear1_grad = oss_model[0].weight.grad
-    oss_linear2_grad = oss_model[1].weight.grad
-    pg_linear1_grad = pg_model[0].weight.grad
-    pg_linear2_grad = pg_model[1].weight.grad
-    check_completely_equal(oss_linear1_grad, pg_linear1_grad)
-    check_completely_equal(oss_linear2_grad, pg_linear2_grad)
+        # step
+        oss_optimizer.sync_grad()
+        pg_optimizer.sync_grad()
 
-    # step
-    oss_optimizer.sync_grad()
-    pg_optimizer.sync_grad()
+        # step
+        oss_optimizer.step()
+        pg_optimizer.step()
 
-    # step
-    oss_optimizer.step()
-    pg_optimizer.step()
-
-    # check updated param
-    check_completely_equal(oss_model[0].weight, pg_model[0].weight)
-    check_completely_equal(oss_model[1].weight, pg_model[1].weight)
+        # check updated param
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param, pg_param)
 
 
 def check_sharded_optim_against_torch_ddp():
@@ -103,61 +92,62 @@ def check_sharded_optim_against_torch_ddp():
     differences in model output and updated parameters are within tolerance.
     """
 
-    # create layer
-    zero_linear1 = nn.Linear(128, 256)
-    zero_linear2 = nn.Linear(256, 512)
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
 
-    # create model
-    zero_model = nn.Sequential(zero_linear1, zero_linear2)
-    torch_model = copy.deepcopy(zero_model)
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()
 
-    zero_model = zero_model.cuda().half()
-    torch_model = DDP(torch_model.cuda())
+        # create model
+        zero_model = model_builder(checkpoint=True).cuda()
+        torch_model = copy.deepcopy(zero_model)
 
-    # create optimizer
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
+        zero_model = zero_model.half()
+        torch_model = DDP(torch_model.cuda())
 
-    # we only test stage 1 here
-    # in `check_sharded_param_consistency.py`, we will test whether
-    # level 1 and 2 will produce exactly the same results
-    zero_optimizer = ShardedOptimizer(zero_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
+        # create optimizer
+        zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=0.001)
 
-    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
+        # we only test stage 1 here
+        # in `check_sharded_param_consistency.py`, we will test whether
+        # level 1 and 2 will produce exactly the same results
+        zero_optimizer = ShardedOptimizer(zero_optimizer,
+                                          overlap_communication=True,
+                                          initial_scale=1,
+                                          clip_grad_norm=0.0)
+        torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)
 
-    # create
-    input_data = torch.rand(32, 128).cuda()
+        # create
+        input_data, _ = next(iter(train_dataloader))
+        input_data = input_data.cuda()
 
-    # zero-dp forward
-    zero_output = zero_model(input_data.half())
+        # zero-dp forward
+        zero_output = zero_model(input_data.half())
 
-    # torch-ddp forward
-    torch_output = torch_model(input_data)
-    check_equal(zero_output, torch_output)
+        # torch-ddp forward
+        torch_output = torch_model(input_data)
+        allclose(zero_output, torch_output.half())
 
-    # zero-dp backward
-    zero_optimizer.backward(zero_output.mean().float())
+        # zero-dp backward
+        zero_optimizer.backward(zero_output.mean().float())
 
-    # torch-ddp backward
-    torch_output.mean().backward()
+        # torch-ddp backward
+        torch_output.mean().backward()
 
-    # check grad
-    zero_linear1_grad = zero_model[0].weight.grad
-    zero_linear2_grad = zero_model[1].weight.grad
-    torch_linear1_grad = torch_model.module[0].weight.grad
-    torch_linear2_grad = torch_model.module[1].weight.grad
-    check_equal(zero_linear1_grad, torch_linear1_grad)
-    check_equal(zero_linear2_grad, torch_linear2_grad)
+        # check grad
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param.grad, torch_param.grad.half())
 
-    # zero-dp step
-    zero_optimizer.sync_grad()
-    zero_optimizer.step()
+        # zero-dp step
+        zero_optimizer.sync_grad()
+        zero_optimizer.step()
 
-    # torch ddp step
-    torch_optimizer.step()
+        # torch ddp step
+        torch_optimizer.step()
 
-    # check updated param
-    check_equal(zero_model[0].weight, torch_model.module[0].weight)
-    check_equal(zero_model[1].weight, torch_model.module[1].weight)
+        # check updated param
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param, torch_param.half())
 
 
 def run_dist(rank, world_size, port):
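
A side effect of moving to registry models is that the hard-coded per-layer checks (`zero_model[0].weight` and friends) no longer work, since the models are arbitrary `nn.Module`s rather than `nn.Sequential`. The zip-over-parameters pattern used above generalizes; a runnable sketch of that comparison as a helper (the helper name is mine, not the repository's, and the tolerances follow the removed `check_equal`):

    import torch
    import torch.nn as nn


    def assert_params_close(model_a: nn.Module, model_b: nn.Module, rtol=1e-4, atol=1e-3):
        # generic form of the zip-based checks in the diff above
        for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
            assert torch.allclose(p_a.float(), p_b.float(), rtol=rtol, atol=atol)


    m1 = nn.Linear(4, 2)
    m2 = nn.Linear(4, 2)
    m2.load_state_dict(m1.state_dict())    # make the parameters identical
    assert_params_close(m1, m2)

Note that `zip` silently truncates if the two models disagree on parameter count; that is safe here only because one model is a deepcopy of the other.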
diff --git a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
index d9b6524c8..41125e3a9 100644
--- a/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
+++ b/tests/test_zero_data_parallel/test_sharded_optim_with_sync_bn.py
@@ -15,6 +15,8 @@ import torch.distributed as dist
 
 
 def run_dist(rank, world_size, port):
+    # this test only runs on resnet18
+    # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
     colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),
diff --git a/tests/test_zero_data_parallel/test_state_dict.py b/tests/test_zero_data_parallel/test_state_dict.py
index 6b6bd6b5d..a71f59c27 100644
--- a/tests/test_zero_data_parallel/test_state_dict.py
+++ b/tests/test_zero_data_parallel/test_state_dict.py
@@ -22,8 +22,8 @@ def run_dist(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model()
+        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model_builder()
         shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)
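
The `test_state_dict.py` fix is a naming correction rather than a behavior change: the first element of the component tuple is a factory, and the old code bound it to `model` before immediately calling it. A short contrast snippet of the two call sites (variable names taken from the diff):

    # before: the factory is misleadingly named `model`, then shadowed by its result
    model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
    model = model()

    # after: the factory is named for what it is
    model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
    model = model_builder()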
diff --git a/tests/test_zero_data_parallel/test_zero_dev_3.py b/tests/test_zero_data_parallel/test_zero_dev_3.py
deleted file mode 100644
index a6fd9df17..000000000
--- a/tests/test_zero_data_parallel/test_zero_dev_3.py
+++ /dev/null
@@ -1,119 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-from common import Net, check_grads, check_params, check_params
-
-
-def checkpoint_wrapper(module, enable=True):
-    if enable:
-        module.forward = partial(checkpoint, module.forward)
-    return module
-
-
-class Net(nn.Module):
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads(model, zero_model)
-        check_params(model, zero_model)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads(model, zero_model, loose=True)
-        check_params(model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()
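
Both deleted files swept every boolean combination of `['checkpoint', 'fp16', 'offload']` by treating the loop index as a bit mask; the surviving `Enumerator` class in `test_zero_gradient_clippling.py` generalizes the same idea to non-boolean values such as `norm_type`. The bit-mask helper from the deleted code, runnable standalone:

    def decode_booleans(intval, bits):
        # read one flag per bit of the loop index
        res = []
        for bit in range(bits):
            mask = 1 << bit
            res.append((intval & mask) == mask)
        return res


    args = ['checkpoint', 'fp16', 'offload']
    configs = [dict(zip(args, decode_booleans(i, len(args)))) for i in range(2**len(args))]
    assert len(configs) == 8    # 2^3 combinations
    assert configs[5] == {'checkpoint': True, 'fp16': False, 'offload': True}    # 5 == 0b101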
diff --git a/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py b/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
deleted file mode 100644
index a3ce53eeb..000000000
--- a/tests/test_zero_data_parallel/test_zero_dev_3_mp4.py
+++ /dev/null
@@ -1,97 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import Net, check_grads_padding, check_params_padding
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-    ddp_model = DDP(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads_padding(ddp_model, zero_model)
-        check_params_padding(ddp_model, zero_model)
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads_padding(ddp_model, zero_model, loose=True)
-        check_params_padding(ddp_model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        if dist.get_rank() == 0:
-            print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()