Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-10-25 01:40:08 +00:00)
only process module's own parameters in ZeRO context: add ZeRO hooks for all modules that contain parameters, and gather only the parameters belonging to the module itself
46 lines
1.3 KiB
Python
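The commit message above describes the behavior this component exercises: in the ZeRO context, hooks are registered on every module that contains parameters, and each hook gathers only the parameters that the module itself owns, not those of its children. A minimal sketch of that idea in plain PyTorch (the helper name and gather_fn are illustrative assumptions, not ColossalAI's actual API):

import torch.nn as nn

def register_own_param_hooks(model: nn.Module, gather_fn):
    # Visit every submodule; only modules that directly own parameters
    # get a hook. recurse=False skips parameters owned by children,
    # e.g. it yields NoLeafModule.weight but not proj1.weight.
    for module in model.modules():
        own_params = list(module.parameters(recurse=False))
        if not own_params:
            continue

        def hook(mod, inputs, params=own_params):
            gather_fn(params)    # e.g. all-gather this module's own shards

        module.register_forward_pre_hook(hook)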
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.nn import CheckpointModule

from .utils.dummy_data_generator import DummyDataGenerator
from .registry import non_distributed_component_funcs

class NoLeafModule(CheckpointModule):
    """
    A module that is not a leaf: it owns child nn.Module instances
    (proj1 and proj2) together with a direct nn.Parameter of its own.
    """
    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.proj1 = nn.Linear(4, 8)
        # A parameter owned by this module directly, not by any child.
        self.weight = nn.Parameter(torch.randn(8, 8))
        self.proj2 = nn.Linear(8, 4)
    def forward(self, x):
        x = self.proj1(x)                 # (N, 4) -> (N, 8)
        x = F.linear(x, self.weight)      # uses the module's own parameter
        x = self.proj2(x)                 # (N, 8) -> (N, 4)
        return x
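# NoLeafModule is the interesting case for the commit above: `weight` is
# owned by the module itself, so parameters(recurse=False) yields it, while
# the proj1/proj2 parameters belong to the child Linear modules. An
# illustrative shape check (a sketch, not part of the original file):
#
#   model = NoLeafModule()
#   assert model(torch.rand(16, 4)).shape == (16, 4)
#   assert len(list(model.parameters(recurse=False))) == 1    # just `weight`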

class DummyDataLoader(DummyDataGenerator):

    def generate(self):
        # One dummy batch: 16 samples with 4 features, binary labels.
        data = torch.rand(16, 4)
        label = torch.randint(low=0, high=2, size=(16,))
        return data, label

@non_distributed_component_funcs.register(name='no_leaf_module')
def get_training_components():

    def model_builder(checkpoint=True):
        return NoLeafModule(checkpoint)

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
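
# An illustrative sketch of how a test would consume this component
# (assumption: the registry exposes a get_callable accessor, as used
# elsewhere in this test suite):
#
#   get_components = non_distributed_component_funcs.get_callable('no_leaf_module')
#   model_builder, trainloader, testloader, optim_cls, criterion = get_components()
#   model = model_builder(checkpoint=True)
#   optimizer = optim_cls(model.parameters(), lr=1e-3)
#   data, label = next(iter(trainloader))
#   loss = criterion(model(data), label)
#   loss.backward()
#   optimizer.step()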