[zero] update zero context init with the updated test utils (#327)

This commit is contained in:
Jiarui Fang
2022-03-08 14:45:01 +08:00
committed by Frank Lee
parent 6268446b81
commit 11bddb6e55
10 changed files with 96 additions and 49 deletions

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.nn import CheckpointModule
from .utils import DummyDataGenerator
from .registry import non_distributed_component_funcs
@@ -15,10 +16,10 @@ class SubNet(nn.Module):
return F.linear(x, weight, self.bias)
class NestedNet(nn.Module):
class NestedNet(CheckpointModule):
def __init__(self) -> None:
super().__init__()
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint)
self.fc1 = nn.Linear(5, 5)
self.sub_fc = SubNet(5)
self.fc2 = nn.Linear(5, 2)
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):
@non_distributed_component_funcs.register(name='nested_model')
def get_training_components():
model = NestedNet()
def model_builder(checkpoint):
return NestedNet(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model, trainloader, testloader, optim, criterion
return model_builder, trainloader, testloader, optim_builder, criterion

View File

@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):
@non_distributed_component_funcs.register(name='repeated_computed_layers')
def get_training_components():
model = NetWithRepeatedlyComputedLayers(checkpoint=True)
def model_builder(checkpoint=True):
return NetWithRepeatedlyComputedLayers(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
optim = torch.optim.Adam(model.parameters(), lr=0.001)
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model, trainloader, testloader, optim, criterion
return model_builder, trainloader, testloader, optim_builder, criterion

View File

@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
@non_distributed_component_funcs.register(name='resnet18')
def get_resnet_training_components():
model = resnet18(num_classes=10)
def model_builder(checkpoint=False):
return resnet18(num_classes=10)
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model, trainloader, testloader, optim, criterion
return model_builder, trainloader, testloader, optim_builder, criterion

View File

@@ -16,10 +16,11 @@ CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None
def run_train():
for get_components_func in non_distributed_component_funcs:
model, train_dataloader, _, optimizer, criterion = get_components_func()
model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
model = model_builder(checkpoint=False)
engine, train_dataloader, *args = colossalai.initialize(model=model,
optimizer=optimizer,
optimizer=optimizer_builder(model),
criterion=criterion,
train_dataloader=train_dataloader)

View File

@@ -9,22 +9,27 @@ import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
from colossalai.zero.init_ctx import ZeroInitContext
from common import CONFIG, Net
from common import CONFIG
from colossalai.utils import free_port
from tests.components_to_test.registry import non_distributed_component_funcs
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
with ZeroInitContext(convert_fp16=True, convert_cuda=True, shard_strategy=TensorShardStrategy(), shard_param=True):
# Note Net(checkpoint=True).cuda() moving to cuda is useless
model = Net(checkpoint=True)
for get_components_func in non_distributed_component_funcs:
model_builder, _, _, _, _ = get_components_func()
with ZeroInitContext(convert_fp16=True,
convert_cuda=True,
shard_strategy=TensorShardStrategy(),
shard_param=True):
model = model_builder(checkpoint=True)
for param in model.parameters():
assert hasattr(param, 'ca_attr')
assert param.ca_attr.data.dtype == torch.half
assert param.ca_attr._data_sharded_tensor.is_sharded
assert param.ca_attr.data.device.type == 'cuda'
for param in model.parameters():
assert hasattr(param, 'ca_attr')
assert param.ca_attr.data.dtype == torch.half
assert param.ca_attr._data_sharded_tensor.is_sharded
assert param.ca_attr.data.device.type == 'cuda'
@pytest.mark.dist

View File

@@ -46,6 +46,8 @@ def _run_shard_param_v2(rank, world_size, port):
sparam = ShardedParamV2(param=param, process_group=None)
allclose(sparam.data, param_ref.data)
sparam.remove_torch_payload()
assert (param.data.numel() == 1)