[zero] new interface for ShardedOptimv2 (#406)

This commit is contained in:
Jiarui Fang
2022-03-14 20:48:41 +08:00
committed by GitHub
parent a9c27be42e
commit 370f567e7d
9 changed files with 51 additions and 35 deletions

View File

@@ -74,8 +74,5 @@ def get_training_components():
sequence_length=sequence_length,
is_distrbuted=True)
def get_optim(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = None
return bert_model_builder, trainloader, testloader, get_optim, criterion
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@@ -49,8 +49,5 @@ def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, optim_builder, criterion
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@@ -43,8 +43,5 @@ def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, optim_builder, criterion
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@@ -29,8 +29,5 @@ def get_resnet_training_components():
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
def optim_builder(model):
return torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, optim_builder, criterion
return model_builder, trainloader, testloader, torch.optim.Adam, criterion