[zero] update zero context init with the updated test utils (#327)

This commit is contained in:
Jiarui Fang
2022-03-08 14:45:01 +08:00
committed by Frank Lee
parent 6268446b81
commit 11bddb6e55
10 changed files with 96 additions and 49 deletions

View File

@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
@non_distributed_component_funcs.register(name='resnet18')
def get_resnet_training_components():
model = resnet18(num_classes=10)
def model_builder(checkpoint=False):
return resnet18(num_classes=10)
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
optim = torch.optim.Adam(model.parameters(), lr=0.001)
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model, trainloader, testloader, optim, criterion
return model_builder, trainloader, testloader, optim_builder, criterion