[zero] adapt for no-leaf module in zero (#535)

Only process a module's own parameters in the ZeRO context.

Add ZeRO hooks for every module that contains parameters of its own.

Gather only the parameters belonging to the module itself.
HELSON 2022-03-28 17:42:18 +08:00 committed by GitHub
parent 705f56107c
commit a30e2b4c24
7 changed files with 70 additions and 26 deletions
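
The change hinges on one PyTorch detail: module.parameters() walks the whole subtree, while module.parameters(recurse=False) yields only the parameters registered directly on that module. A minimal standalone sketch of the distinction (not part of the patch; the class name is made up):

import torch
import torch.nn as nn

class ToyNoLeaf(nn.Module):
    """A hypothetical 'no-leaf' module: it has child modules AND a parameter of its own."""

    def __init__(self):
        super().__init__()
        self.child = nn.Linear(4, 4)                    # parameters owned by the child module
        self.weight = nn.Parameter(torch.randn(4, 4))   # parameter owned by this module itself

m = ToyNoLeaf()
print(len(list(m.parameters())))                # 3 -> child.weight, child.bias, self.weight
print(len(list(m.parameters(recurse=False))))   # 1 -> only self.weight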


@@ -64,18 +64,13 @@ class PostBackwardFunction(torch.autograd.Function):
 def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
     r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
-    has_children = False
+    # Add hooks for submodules
     for child_name, child in module.named_children():
         register_ophooks_recursively(child, ophook_list, name + child_name)
-        has_children = True
-    # Early return on modules with no parameters or buffers that
-    # are not in their children.
-    if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
-        return
-    # return if the module has not childern.
-    if has_children:
+    # Early return on modules with no parameters.
+    if len(list(module.parameters(recurse=False))) == 0:
         return
     if ophook_list is not None:
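
Before this patch a module with children returned early and never received hooks, so a module that also owned parameters directly (like the no-leaf test case added below) was skipped. The new rule is simply: register hooks on every module that directly owns at least one parameter. A toy restatement of that predicate (not the library function), applied to the ToyNoLeaf sketch above:

import torch.nn as nn

def modules_needing_hooks(model: nn.Module):
    # Mirrors the new early-return in register_ophooks_recursively: a module gets
    # hooks iff it directly owns at least one parameter, children or not.
    return [name or '<root>'
            for name, mod in model.named_modules()
            if len(list(mod.parameters(recurse=False))) > 0]

print(modules_needing_hooks(ToyNoLeaf()))   # ['<root>', 'child'] -- the root is no longer skipped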


@@ -31,11 +31,11 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload

@@ -44,20 +44,20 @@ class ZeroHook(BaseOpHook):
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
         # Store local accumulated grad shard

@@ -77,11 +77,11 @@ class ZeroHook(BaseOpHook):
     def post_bwd_exec(self, module: torch.nn.Module, input):
         tensor_list = []
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()

     def pre_iter(self):
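
With every loop switched to parameters(recurse=False), each module now gathers exactly its own shards before it runs and re-shards them afterwards, instead of repeatedly gathering its children's parameters as well. A schematic of that per-module pattern using plain PyTorch forward hooks; gather_fn and release_fn stand in for a shard strategy, and this is not the real ZeroHook, which also hooks the backward pass and moves tensors between devices:

import torch.nn as nn

def attach_schematic_hooks(module: nn.Module, gather_fn, release_fn):
    own_params = list(module.parameters(recurse=False))
    if not own_params:
        return   # nothing owned directly -> no hooks, matching register_ophooks_recursively

    def pre_forward(mod, inputs):
        gather_fn(own_params)        # materialize full tensors for this module only

    def post_forward(mod, inputs, output):
        release_fn(own_params)       # drop the full copies again

    module.register_forward_pre_hook(pre_forward)
    module.register_forward_hook(post_forward)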


@@ -12,6 +12,12 @@ from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers

+def _substitute_init_recursively(cls, func):
+    for subcls in cls.__subclasses__():
+        _substitute_init_recursively(subcls, func)
+        func(subcls)
+
 class InsertPostInitMethodToModuleSubClasses(object):

     def __init__(self):

@@ -41,8 +47,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
         # Replace .__init__() for all existing subclasses of torch.nn.Module
         # Excution self._post_init_method after the default init function.
-        for subclass in torch.nn.modules.module.Module.__subclasses__():
-            _enable_class(subclass)
+        _substitute_init_recursively(torch.nn.modules.module.Module, _enable_class)

         # holding on to the current __init__subclass__ for exit
         torch.nn.modules.module.Module._old_init_subclass = (torch.nn.modules.module.Module.__init_subclass__)

@@ -57,8 +62,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
             cls.__init__ = cls._old_init

         # Replace .__init__() for all existing subclasses of torch.nn.Module
-        for subclass in torch.nn.modules.module.Module.__subclasses__():
-            _disable_class(subclass)
+        _substitute_init_recursively(torch.nn.modules.module.Module, _disable_class)

         # Replace .__init__() for future subclasses of torch.nn.Module
         torch.nn.modules.module.Module.__init_subclass__ = (torch.nn.modules.module.Module._old_init_subclass)

@@ -144,7 +148,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.
        """
-        for param in module.parameters():
+        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'col_attr'):
                continue

@@ -173,7 +177,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
        # We must cast them
-        for buffer in module.buffers():
+        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            if self.convert_fp16:
                buffer.data = cast_tensor_to_fp16(buffer.data)
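
The helper _substitute_init_recursively exists because cls.__subclasses__() returns only direct subclasses, so patching just Module.__subclasses__() misses anything deeper in the inheritance tree (a subclass of a subclass would keep its original __init__). A standalone illustration with made-up class names:

class A: pass
class B(A): pass
class C(B): pass        # a grandchild of A

print(A.__subclasses__())        # [<class 'B'>] -- C is not listed

def visit_all_subclasses(cls, func):
    # Same shape as _substitute_init_recursively: depth-first over the whole tree.
    for subcls in cls.__subclasses__():
        visit_all_subclasses(subcls, func)
        func(subcls)

seen = []
visit_all_subclasses(A, seen.append)
print(seen)                      # [<class 'C'>, <class 'B'>] -- every descendant is visited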


@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert
+from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module


@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.nn import CheckpointModule
+
+from .utils.dummy_data_generator import DummyDataGenerator
+from .registry import non_distributed_component_funcs
+
+
+class NoLeafModule(CheckpointModule):
+    """
+    In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
+    """
+
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint=checkpoint)
+        self.proj1 = nn.Linear(4, 8)
+        self.weight = nn.Parameter(torch.randn(8, 8))
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = F.linear(x, self.weight)
+        x = self.proj2(x)
+        return x
+
+
+class DummyDataLoader(DummyDataGenerator):
+
+    def generate(self):
+        data = torch.rand(16, 4)
+        label = torch.randint(low=0, high=2, size=(16,))
+        return data, label
+
+
+@non_distributed_component_funcs.register(name='no_leaf_module')
+def get_training_components():
+
+    def model_builder(checkpoint=True):
+        return NoLeafModule(checkpoint)
+
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+    criterion = torch.nn.CrossEntropyLoss()
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion
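
NoLeafModule is exactly the case the old registration logic missed: it has child Linear layers and also registers self.weight on itself. Assuming CheckpointModule adds no parameters of its own, the hooks see the following split (a sketch; it requires the class above to be importable):

model = NoLeafModule(checkpoint=False)
print(len(list(model.parameters())))               # 5 -> proj1.{weight,bias}, weight, proj2.{weight,bias}
print(len(list(model.parameters(recurse=False))))  # 1 -> only self.weight, handled by the root module's own hooks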


@@ -24,7 +24,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
 @parameterize("enable_autocast", [True])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
-    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)


@@ -45,7 +45,7 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
 @parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
 def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio):
-    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
     if use_cpuadam and cpu_offload is False:
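
For reference, a hypothetical smoke test of the new component outside the ZeRO machinery, using only names defined in the files above (the import path is assumed from the test layout; the real tests additionally wrap the model with a shard strategy and compare gradients):

from tests.components_to_test.registry import non_distributed_component_funcs   # assumed path

get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
model_builder, trainloader, testloader, optim_class, criterion = get_components_func()

model = model_builder(checkpoint=False)
data, label = trainloader.generate()     # DummyDataLoader.generate() from the new component file
loss = criterion(model(data), label)
loss.backward()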