Mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[zero] adapt for no-leaf module in zero (#535)
Only process a module's own parameters in the Zero context, add zero hooks for all modules that contain parameters, and gather only the parameters that belong to the module itself.
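To illustrate the idea, here is a minimal, hedged sketch in plain PyTorch (not the actual ColossalAI implementation): a forward pre-hook is registered on every module that directly owns parameters, and each hook sees only the parameters registered on that module (named_parameters(recurse=False)), never those of its submodules. All names and the print-only hook body are illustrative.

import torch
import torch.nn as nn
import torch.nn.functional as F


def register_param_hooks(model: nn.Module):
    # Attach a forward pre-hook to every module that directly owns at least one
    # parameter ("add zero hooks for all modules that contain parameters").
    # recurse=False restricts the view to the module's own parameters, so each
    # hook handles only what belongs to that module, not to its children.
    handles = []
    for name, module in model.named_modules():
        own_params = dict(module.named_parameters(recurse=False))
        if not own_params:
            continue

        def pre_hook(mod, inputs, _name=name, _keys=tuple(own_params)):
            # A real ZeRO hook would gather/shard these parameters around the
            # forward pass; printing is a stand-in for illustration.
            print(f"gathering params of '{_name or 'root'}': {_keys}")

        handles.append(module.register_forward_pre_hook(pre_hook))
    return handles


class TinyNoLeaf(nn.Module):
    # A non-leaf module that owns a parameter of its own besides a child Linear.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.weight = nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return F.linear(self.proj(x), self.weight)


if __name__ == "__main__":
    model = TinyNoLeaf()
    register_param_hooks(model)
    model(torch.rand(2, 4))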
tests/components_to_test/__init__.py
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module
tests/components_to_test/no_leaf_module.py (new file, 45 lines)
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from colossalai.nn import CheckpointModule
+from .utils.dummy_data_generator import DummyDataGenerator
+from .registry import non_distributed_component_funcs
+
+
+class NoLeafModule(CheckpointModule):
+    """
+    A no-leaf module: it holds subordinate nn.Modules as well as an nn.Parameter of its own.
+    """
+
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint=checkpoint)
+        self.proj1 = nn.Linear(4, 8)
+        self.weight = nn.Parameter(torch.randn(8, 8))
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = F.linear(x, self.weight)
+        x = self.proj2(x)
+        return x
+
+
+class DummyDataLoader(DummyDataGenerator):
+
+    def generate(self):
+        data = torch.rand(16, 4)
+        label = torch.randint(low=0, high=2, size=(16,))
+        return data, label
+
+
+@non_distributed_component_funcs.register(name='no_leaf_module')
+def get_training_components():
+
+    def model_builder(checkpoint=True):
+        return NoLeafModule(checkpoint)
+
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+
+    criterion = torch.nn.CrossEntropyLoss()
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
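As a usage note, here is a hedged sketch of how a test could consume the component registered above as a plain PyTorch model, outside of any ZeRO/sharded context; the registry import path and the direct call to trainloader.generate() are assumptions for illustration, while get_callable mirrors the test diff below.

import torch

# Assumed import path for the registry; adjust to the actual test layout.
from tests.components_to_test.registry import non_distributed_component_funcs

get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
model_builder, trainloader, testloader, optimizer_class, criterion = get_components_func()

model = model_builder(checkpoint=False)   # plain forward, no activation checkpointing
optimizer = optimizer_class(model.parameters(), lr=1e-3)

data, label = trainloader.generate()      # (16, 4) inputs, (16,) class labels
optimizer.zero_grad()
loss = criterion(model(data), label)
loss.backward()
optimizer.step()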
@@ -24,7 +24,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
 @parameterize("enable_autocast", [True])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
-    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -45,7 +45,7 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 @parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
-    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()

     if use_cpuadam and cpu_offload is False: