[unit test] Refactored test cases with component func (#339)

* refactored tests with component functions

* fixed a bug
Frank Lee 2022-03-11 14:09:09 +08:00
parent de46450461
commit 526a318032
11 changed files with 148 additions and 420 deletions
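For context: this commit replaces hand-built models, dataloaders, and optimizers in each test with lookups into a shared component registry. The registry itself is not part of this diff, so the following is a minimal sketch reconstructed from the call sites below (`register`, `get_callable`, and direct iteration); treat anything beyond those three entry points as an assumption. Each registered function returns a tuple, unpacked in the tests below as (model_builder, train_dataloader, test_dataloader, optimizer or optimizer_builder, criterion).

class Registry:
    # Minimal sketch of the component registry implied by the call sites
    # in this diff; the real tests/components_to_test/registry.py may differ.

    def __init__(self):
        self._registry = {}

    def register(self, name):
        # usage: @non_distributed_component_funcs.register(name='nested_model')
        def decorator(func):
            self._registry[name] = func
            return func

        return decorator

    def get_callable(self, name):
        # usage: get_components_func = registry.get_callable('resnet18')
        return self._registry[name]

    def __iter__(self):
        # usage: for get_components_func in registry: ...
        return iter(self._registry.values())


non_distributed_component_funcs = Registry()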

View File

@@ -43,7 +43,7 @@ class DummyDataLoader(DummyDataGenerator):
 @non_distributed_component_funcs.register(name='nested_model')
 def get_training_components():

-    def model_builder(checkpoint):
+    def model_builder(checkpoint=True):
         return NestedNet(checkpoint)

     trainloader = DummyDataLoader()

View File

@@ -3,12 +3,23 @@ from abc import ABC, abstractmethod

 class DummyDataGenerator(ABC):

+    def __init__(self, length=10):
+        self.length = length
+
     @abstractmethod
     def generate(self):
         pass

     def __iter__(self):
+        self.step = 0
         return self

     def __next__(self):
-        return self.generate()
+        if self.step < self.length:
+            self.step += 1
+            return self.generate()
+        else:
+            raise StopIteration
+
+    def __len__(self):
+        return self.length
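The patch above turns every `DummyDataGenerator` subclass into a finite iterable of `length` batches rather than an endless generator. A hypothetical subclass showing the contract (the `generate` body and the `RandomDataLoader` name are illustrative, not from the diff):

import torch
from abc import ABC, abstractmethod


class DummyDataGenerator(ABC):
    # condensed copy of the class as patched above

    def __init__(self, length=10):
        self.length = length

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        self.step = 0
        return self

    def __next__(self):
        if self.step < self.length:
            self.step += 1
            return self.generate()
        raise StopIteration

    def __len__(self):
        return self.length


class RandomDataLoader(DummyDataGenerator):
    # generate() supplies one (data, label) batch; iteration now stops
    # after `length` steps instead of looping forever.

    def generate(self):
        return torch.rand(16, 5), torch.rand(16, 1)


loader = RandomDataLoader(length=4)
assert len(loader) == 4
assert sum(1 for _ in loader) == 4    # exactly 4 batches per epoch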

View File

@@ -1,21 +1,14 @@
-import os
 from functools import partial
-from pathlib import Path

 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.amp.amp_type import AMP_TYPE
-from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.trainer import Trainer
-from colossalai.utils import MultiTimer, free_port, get_dataloader
-from torch.optim import Adam
-from torchvision import transforms
-from torchvision.datasets import CIFAR10
-from torchvision.models import resnet18
+from colossalai.utils import MultiTimer, free_port
+from tests.components_to_test.registry import non_distributed_component_funcs

 BATCH_SIZE = 16
 IMG_SIZE = 32
@@ -29,40 +22,13 @@ CONFIG = dict(

 def run_trainer_no_pipeline(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    # build model
-    model = resnet18(num_classes=10)
-
-    # build dataloaders
-    train_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                            download=True,
-                            transform=transforms.Compose([
-                                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                                transforms.ToTensor(),
-                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                            ]))
-    test_dataset = CIFAR10(root=Path(os.environ['DATA']),
-                           train=False,
-                           download=True,
-                           transform=transforms.Compose([
-                               transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
-                               transforms.ToTensor(),
-                               transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
-                           ]))
-    train_dataloader = get_dataloader(dataset=train_dataset,
-                                      shuffle=True,
-                                      batch_size=BATCH_SIZE,
-                                      pin_memory=True,
-                                      drop_last=True)
-    test_dataloader = get_dataloader(dataset=test_dataset, batch_size=BATCH_SIZE, pin_memory=True, drop_last=True)
-
-    # build optimizer
-    optimizer = Adam(model.parameters(), lr=0.001)
-    criterion = nn.CrossEntropyLoss()
-
-    engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                            optimizer=optimizer,
-                                                            criterion=criterion,
-                                                            train_dataloader=train_dataloader)
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model = model_builder()
+        optimizer = optimizer_builder(model)
+        engine, train_dataloader, *_ = colossalai.initialize(model=model,
+                                                             optimizer=optimizer,
+                                                             criterion=criterion,
+                                                             train_dataloader=train_dataloader)
@@ -78,10 +44,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
         trainer.fit(train_dataloader=train_dataloader,
                     test_dataloader=test_dataloader,
                     epochs=NUM_EPOCHS,
-                    max_steps=100,
+                    max_steps=5,
                     display_progress=True,
                     test_interval=5)
-    gpc.destroy()

     torch.cuda.empty_cache()

View File

@@ -20,6 +20,7 @@ from torch.nn.utils import clip_grad_norm_

 class Enumerator:
+
     def __init__(self, arg_names: List[str], arg_values: List[tuple]) -> None:
         self.arg_names = arg_names
         self.enums = Enumerator.all_enumerate(arg_values)
@@ -49,11 +50,12 @@ class Enumerator:

 def checkpoint_wrapper(module, enable=True):
     if enable:
-        module.forward = partial(checkpoint, module.forward)
+        module.forward = partial(checkpoint, module.forward, False)
     return module


 class Net(nn.Module):
+
     def __init__(self, checkpoint=False) -> None:
         super().__init__()
         self.fc1 = nn.Linear(5, 5)
@@ -61,13 +63,7 @@ class Net(nn.Module):
         self.fc3 = nn.Linear(5, 1)
         if checkpoint:
             self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
+        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

     def forward(self, x):
         for layer in self.layers:
@@ -158,12 +154,7 @@ def check_config(checkpoint=False, fp16=False, offload=False, norm_type=2.0):

 def run_dist(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     args = ['checkpoint', 'fp16', 'offload', 'norm_type']
     arg_values = [(False, True), (False, True), (False, True), (1.0, 2.0, float('inf'))]
@@ -176,7 +167,7 @@ def run_dist(rank, world_size, port):
             check_config()


-@ pytest.mark.dist
+@pytest.mark.dist
 def test_zero_clip_grad():
     world_size = 4
     run_func = partial(run_dist, world_size=world_size, port=free_port())
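The `checkpoint_wrapper` change in this file passes an extra `False` argument to `checkpoint`, which appears to correspond to an activation-offload flag in colossalai's checkpoint API (an inference from the call site, not confirmed by this diff). A torch-only sketch of the underlying wrapper pattern, using `torch.utils.checkpoint` in place of the colossalai function:

from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def checkpoint_wrapper(module: nn.Module, enable: bool = True) -> nn.Module:
    # Re-bind forward so activations are recomputed during backward
    # instead of being stored; partial() captures the original bound
    # method, so there is no recursion after the reassignment.
    if enable:
        module.forward = partial(checkpoint, module.forward)
    return module


fc = checkpoint_wrapper(nn.Linear(5, 5))
out = fc(torch.rand(2, 5, requires_grad=True))
out.sum().backward()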

View File

@@ -43,23 +43,6 @@ def checkpoint_wrapper(module, enable=True):
     return module


-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
     if loose:
         return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3)

View File

@@ -13,7 +13,8 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.test_zero_data_parallel.common import CONFIG, Net, allclose
+from tests.test_zero_data_parallel.common import CONFIG, allclose
+from tests.components_to_test.registry import non_distributed_component_funcs


 def _run_shard_tensor(rank, world_size, port):
@@ -68,19 +69,20 @@ def _run_test_shard_param(rank, world_size, port):
     print(param_ref.data)

     logger = get_dist_logger()
-    model = Net()
-
-    # add an attribute as ca_attr to hijack the access to param.data
-    for _, param in model.named_parameters():
-        numel_ref = (param.numel() + world_size - 1) // world_size
-        param.ca_attr = ShardedParam(param)
-        param.ca_attr.shard()
-        param_data = param.ca_attr.payload(torch.device('cpu'))
-        assert (numel_ref == param_data.numel())
-
-    for _, param in model.named_parameters():
-        param.ca_attr.gather()
-        param_data = param.ca_attr.payload(torch.device('cpu'))
+    for get_components_func in non_distributed_component_funcs:
+        model_builder, *_ = get_components_func()
+        model = model_builder(checkpoint=True)
+        # add an attribute as col_attr to hijack the access to param.data
+        for _, param in model.named_parameters():
+            numel_ref = (param.numel() + world_size - 1) // world_size
+            param.col_attr = ShardedParam(param)
+            param.col_attr.shard()
+            param_data = param.col_attr.payload(torch.device('cpu'))
+            assert (numel_ref == param_data.numel())
+
+        for _, param in model.named_parameters():
+            param.col_attr.gather()
+            param_data = param.col_attr.payload(torch.device('cpu'))

     disable_existing_loggers([logger])
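The `numel_ref` assertion in this hunk relies on ceiling division: after sharding, each of the `world_size` ranks holds `ceil(numel / world_size)` elements, with the flattened parameter zero-padded when it does not divide evenly. A standalone illustration of that arithmetic (plain PyTorch; names are illustrative):

import torch


def shard_size(numel: int, world_size: int) -> int:
    # same expression as the test: ceil(numel / world_size)
    return (numel + world_size - 1) // world_size


param = torch.rand(5, 5)          # 25 elements
world_size = 4
per_rank = shard_size(param.numel(), world_size)
assert per_rank == 7              # ceil(25 / 4)

# pad the flattened parameter so it splits evenly across ranks
padded = torch.nn.functional.pad(param.flatten(), (0, per_rank * world_size - param.numel()))
shards = padded.chunk(world_size)
assert all(s.numel() == per_rank for s in shards)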

View File

@@ -3,19 +3,13 @@ import colossalai
 import copy
 import pytest
 import torch.multiprocessing as mp
-import torch.nn as nn
 from colossalai.zero import ShardedOptimizer
 from torch.nn.parallel import DistributedDataParallel as DDP
 from colossalai.utils import free_port
 from functools import partial
-
-
-def check_equal(a, b):
-    """
-    This function checks if two tensors are equal within tolerance
-    """
-    assert torch.allclose(a.float(), b.float(), rtol=1e-4, atol=1e-3), f'a = {a}, b = {b}'
+from common import allclose
+from tests.components_to_test.registry import non_distributed_component_funcs


 def check_completely_equal(a, b):
@@ -36,18 +30,16 @@ def check_sharded_param_consistency():
     pg: partition gradients and optimizer states
     """
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()

-    # create layers
-    oss_linear1 = nn.Linear(128, 256)
-    oss_linear2 = nn.Linear(256, 512)
-
-    # create model
-    oss_model = nn.Sequential(oss_linear1, oss_linear2)
-    pg_model = copy.deepcopy(oss_model)
-
-    oss_model = oss_model.cuda().half()
-    pg_model = pg_model.cuda().half()
+        # create model
+        oss_model = model_builder(checkpoint=True).cuda().half()
+        pg_model = copy.deepcopy(oss_model)

-    # create optimizer
-    oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
-    pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
+        # create optimizer
+        oss_optimizer = torch.optim.Adam(oss_model.parameters(), lr=0.001)
+        pg_optimizer = torch.optim.Adam(pg_model.parameters(), lr=0.001)
@@ -59,7 +51,8 @@ def check_sharded_param_consistency():
                                           clip_grad_norm=0.0)

         # create
-    input_data = torch.rand(32, 128).cuda().half()
+        data, label = next(iter(train_dataloader))
+        input_data = data.cuda().half()

         # forward
         oss_output = oss_model(input_data)
@@ -73,12 +66,8 @@ def check_sharded_param_consistency():

         # check grad
         # as this param is small, the backward reduction
         # will not be fired
-    oss_linear1_grad = oss_model[0].weight.grad
-    oss_linear2_grad = oss_model[1].weight.grad
-    pg_linear1_grad = pg_model[0].weight.grad
-    pg_linear2_grad = pg_model[1].weight.grad
-    check_completely_equal(oss_linear1_grad, pg_linear1_grad)
-    check_completely_equal(oss_linear2_grad, pg_linear2_grad)
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param.grad, pg_param.grad)

         # step
         oss_optimizer.sync_grad()
@@ -89,8 +78,8 @@ def check_sharded_param_consistency():
         pg_optimizer.step()

         # check updated param
-    check_completely_equal(oss_model[0].weight, pg_model[0].weight)
-    check_completely_equal(oss_model[1].weight, pg_model[1].weight)
+        for oss_param, pg_param in zip(oss_model.parameters(), pg_model.parameters()):
+            check_completely_equal(oss_param, pg_param)


 def check_sharded_optim_against_torch_ddp():
@@ -103,15 +92,17 @@ def check_sharded_optim_against_torch_ddp():
     differences in model output and updated parameters are within tolerance.
     """
+    test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
+
+    for name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(name)
+        model_builder, train_dataloader, *_ = get_components_func()

-    # create layer
-    zero_linear1 = nn.Linear(128, 256)
-    zero_linear2 = nn.Linear(256, 512)
-
-    # create model
-    zero_model = nn.Sequential(zero_linear1, zero_linear2)
-    torch_model = copy.deepcopy(zero_model)
-
-    zero_model = zero_model.cuda().half()
-    torch_model = DDP(torch_model.cuda())
+        # create model
+        zero_model = model_builder(checkpoint=True).cuda()
+        torch_model = copy.deepcopy(zero_model)
+
+        zero_model = zero_model.half()
+        torch_model = DDP(torch_model.cuda())
@@ -120,19 +111,22 @@ def check_sharded_optim_against_torch_ddp():
         # we only test stage 1 here
         # in `check_sharded_param_consistency.py`, we will test whether
         # level 1 and 2 will produce exactly the same results
-    zero_optimizer = ShardedOptimizer(zero_optimizer, overlap_communication=True, initial_scale=1, clip_grad_norm=0.0)
+        zero_optimizer = ShardedOptimizer(zero_optimizer,
+                                          overlap_communication=True,
+                                          initial_scale=1,
+                                          clip_grad_norm=0.0)

         torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=0.001)

         # create
-    input_data = torch.rand(32, 128).cuda()
+        input_data, _ = next(iter(train_dataloader))
+        input_data = input_data.cuda()

         # zero-dp forward
         zero_output = zero_model(input_data.half())

         # torch-ddp forward
         torch_output = torch_model(input_data)
-    check_equal(zero_output, torch_output)
+        allclose(zero_output, torch_output.half())

         # zero-dp backward
         zero_optimizer.backward(zero_output.mean().float())
@@ -141,12 +135,8 @@ def check_sharded_optim_against_torch_ddp():
         torch_output.mean().backward()

         # check grad
-    zero_linear1_grad = zero_model[0].weight.grad
-    zero_linear2_grad = zero_model[1].weight.grad
-    torch_linear1_grad = torch_model.module[0].weight.grad
-    torch_linear2_grad = torch_model.module[1].weight.grad
-    check_equal(zero_linear1_grad, torch_linear1_grad)
-    check_equal(zero_linear2_grad, torch_linear2_grad)
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param.grad, torch_param.grad.half())

         # zero-dp step
         zero_optimizer.sync_grad()
@@ -156,8 +146,8 @@ def check_sharded_optim_against_torch_ddp():
         torch_optimizer.step()

         # check updated param
-    check_equal(zero_model[0].weight, torch_model.module[0].weight)
-    check_equal(zero_model[1].weight, torch_model.module[1].weight)
+        for oss_param, torch_param in zip(zero_model.parameters(), torch_model.parameters()):
+            allclose(oss_param, torch_param.half())


 def run_dist(rank, world_size, port):

View File

@@ -15,6 +15,8 @@ import torch.distributed as dist

 def run_dist(rank, world_size, port):
+    # this test only runs on resnet18
+    # as this model has sync batch normalization
     # need to configure cudnn deterministic so that
     # randomness of convolution layers will be disabled
     colossalai.launch(config=dict(zero=dict(level=2, partition_grad=True),

View File

@@ -22,8 +22,8 @@ def run_dist(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18']
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
-        model = model()
+        model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model = model_builder()
         shard_strategy = TensorShardStrategy()
         model = model.half().cuda()
         zero_model = ShardedModelV2(deepcopy(model), shard_strategy)

View File

@@ -1,119 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-
-from common import Net, check_grads, check_params
-
-
-def checkpoint_wrapper(module, enable=True):
-    if enable:
-        module.forward = partial(checkpoint, module.forward)
-    return module
-
-
-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.fc2 = nn.Linear(5, 5)
-        self.fc3 = nn.Linear(5, 1)
-        if checkpoint:
-            self.fc1 = checkpoint_wrapper(self.fc1)
-        self.layers = [
-            self.fc1,
-            self.fc2,
-            self.fc1,
-            self.fc2,
-            self.fc3
-        ]
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = layer(x)
-        return x
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads(model, zero_model)
-        check_params(model, zero_model)
-
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads(model, zero_model, loose=True)
-        check_params(model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 1
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()
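This deleted file (and the near-identical one after it) enumerated every boolean test configuration by treating the loop index as a bit mask: with `args = ['checkpoint', 'fp16', 'offload']`, `range(2 ** 3)` covers all eight combinations, and for example `j = 5` (binary `101`) decodes to `checkpoint=True, fp16=False, offload=True`. The replacement tests express the same idea with the `Enumerator` helper seen earlier. A self-contained check of that decoding:

def decode_booleans(intval, bits):
    # copied from the deleted file: test bit `bit` of intval
    res = []
    for bit in range(bits):
        mask = 1 << bit
        res.append((intval & mask) == mask)
    return res


args = ['checkpoint', 'fp16', 'offload']
kwargs = dict(zip(args, decode_booleans(5, len(args))))
assert kwargs == {'checkpoint': True, 'fp16': False, 'offload': True}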

View File

@@ -1,97 +0,0 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
-import copy
-from functools import partial
-
-import colossalai
-import pytest
-import torch
-import torch.distributed as dist
-import torch.multiprocessing as mp
-import torch.nn as nn
-from colossalai.logging import disable_existing_loggers
-from colossalai.utils import checkpoint, free_port
-from colossalai.zero.sharded_model import ShardedModel
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-from common import Net, check_grads_padding, check_params_padding
-
-
-def run_step(model, optimizer, x, enable_autocast=False):
-    model.train()
-    optimizer.zero_grad()
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        y = model(x)
-        loss = y.sum()
-    loss = loss.float()
-    loss.backward()
-    optimizer.step()
-
-
-def decode_booleans(intval, bits):
-    res = []
-    for bit in range(bits):
-        mask = 1 << bit
-        res.append((intval & mask) == mask)
-    return res
-
-
-def check_config(checkpoint=False, fp16=False, offload=False):
-    model = Net(checkpoint=checkpoint).cuda()
-    zero_model = copy.deepcopy(model)
-    ddp_model = DDP(model)
-
-    offload_config = {}
-    if offload:
-        offload_config['device'] = 'cpu'
-        zero_model = zero_model.cpu()
-
-    zero_model = ShardedModel(zero_model, mixed_precision=fp16, offload_config=offload_config)
-
-    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
-    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1e-3)
-
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=fp16)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=fp16)
-        check_grads_padding(ddp_model, zero_model)
-        check_params_padding(ddp_model, zero_model)
-
-    for _ in range(5):
-        x = torch.rand(2, 5).cuda()
-        run_step(ddp_model, optimizer, x, enable_autocast=False)
-        run_step(zero_model, zero_optimizer, x, enable_autocast=False)
-        check_grads_padding(ddp_model, zero_model, loose=True)
-        check_params_padding(ddp_model, zero_model, loose=True)
-
-
-def run_dist(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={},
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
-
-    args = ['checkpoint', 'fp16', 'offload']
-
-    def pack_args(i):
-        booleans = decode_booleans(i, len(args))
-        return {arg: booleans[idx] for idx, arg in enumerate(args)}
-
-    for j in range(2 ** len(args)):
-        kwargs = pack_args(j)
-        if dist.get_rank() == 0:
-            print(kwargs)
-        check_config(**kwargs)
-
-
-@pytest.mark.dist
-def test_zero_level_3():
-    world_size = 4
-    run_func = partial(run_dist, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-if __name__ == '__main__':
-    test_zero_level_3()