mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-05 15:44:49 +00:00
[zero] update zero context init with the updated test utils (#327)
This commit is contained in:
@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
|
||||
|
||||
@non_distributed_component_funcs.register(name='resnet18')
|
||||
def get_resnet_training_components():
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
def model_builder(checkpoint=False):
|
||||
return resnet18(num_classes=10)
|
||||
|
||||
trainloader = get_cifar10_dataloader(train=True)
|
||||
testloader = get_cifar10_dataloader(train=False)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model, trainloader, testloader, optim, criterion
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
|
||||
Reference in New Issue
Block a user