mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-12 12:47:21 +00:00
[zero] new interface for ShardedOptimv2 (#406)
This commit is contained in:
@@ -74,8 +74,5 @@ def get_training_components():
|
||||
sequence_length=sequence_length,
|
||||
is_distrbuted=True)
|
||||
|
||||
def get_optim(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = None
|
||||
return bert_model_builder, trainloader, testloader, get_optim, criterion
|
||||
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
||||
|
@@ -49,8 +49,5 @@ def get_training_components():
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
return model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
||||
|
@@ -43,8 +43,5 @@ def get_training_components():
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
return model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
||||
|
@@ -29,8 +29,5 @@ def get_resnet_training_components():
|
||||
trainloader = get_cifar10_dataloader(train=True)
|
||||
testloader = get_cifar10_dataloader(train=False)
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
return model_builder, trainloader, testloader, torch.optim.Adam, criterion
|
||||
|
Reference in New Issue
Block a user