mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 21:40:02 +00:00
[zero] update zero context init with the updated test utils (#327)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from colossalai.nn import CheckpointModule
|
||||
from .utils import DummyDataGenerator
|
||||
from .registry import non_distributed_component_funcs
|
||||
|
||||
@@ -15,10 +16,10 @@ class SubNet(nn.Module):
|
||||
return F.linear(x, weight, self.bias)
|
||||
|
||||
|
||||
class NestedNet(nn.Module):
|
||||
class NestedNet(CheckpointModule):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
def __init__(self, checkpoint=False) -> None:
|
||||
super().__init__(checkpoint)
|
||||
self.fc1 = nn.Linear(5, 5)
|
||||
self.sub_fc = SubNet(5)
|
||||
self.fc2 = nn.Linear(5, 2)
|
||||
@@ -41,9 +42,15 @@ class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
@non_distributed_component_funcs.register(name='nested_model')
|
||||
def get_training_components():
|
||||
model = NestedNet()
|
||||
|
||||
def model_builder(checkpoint):
|
||||
return NestedNet(checkpoint)
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model, trainloader, testloader, optim, criterion
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
|
@@ -36,9 +36,15 @@ class DummyDataLoader(DummyDataGenerator):
|
||||
|
||||
@non_distributed_component_funcs.register(name='repeated_computed_layers')
|
||||
def get_training_components():
|
||||
model = NetWithRepeatedlyComputedLayers(checkpoint=True)
|
||||
|
||||
def model_builder(checkpoint=True):
|
||||
return NetWithRepeatedlyComputedLayers(checkpoint)
|
||||
|
||||
trainloader = DummyDataLoader()
|
||||
testloader = DummyDataLoader()
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model, trainloader, testloader, optim, criterion
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
|
@@ -22,9 +22,15 @@ def get_cifar10_dataloader(train):
|
||||
|
||||
@non_distributed_component_funcs.register(name='resnet18')
|
||||
def get_resnet_training_components():
|
||||
model = resnet18(num_classes=10)
|
||||
|
||||
def model_builder(checkpoint=False):
|
||||
return resnet18(num_classes=10)
|
||||
|
||||
trainloader = get_cifar10_dataloader(train=True)
|
||||
testloader = get_cifar10_dataloader(train=False)
|
||||
optim = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
def optim_builder(model):
|
||||
return torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
return model, trainloader, testloader, optim, criterion
|
||||
return model_builder, trainloader, testloader, optim_builder, criterion
|
||||
|
Reference in New Issue
Block a user