mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[zero] adapt for no-leaf module in zero (#535)
only process module's own parameters in Zero context add zero hooks for all modules that contrain parameters gather parameters only belonging to module itself
This commit is contained in:
@@ -64,18 +64,13 @@ class PostBackwardFunction(torch.autograd.Function):
|
||||
def register_ophooks_recursively(module: torch.nn.Module, ophook_list: List[BaseOpHook] = None, name: str = ""):
|
||||
r"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
has_children = False
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
register_ophooks_recursively(child, ophook_list, name + child_name)
|
||||
has_children = True
|
||||
|
||||
# Early return on modules with no parameters or buffers that
|
||||
# are not in their children.
|
||||
if (len(list(module.named_parameters(recurse=False))) == 0 and len(list(module.named_buffers(recurse=False))) == 0):
|
||||
return
|
||||
|
||||
# return if the module has not childern.
|
||||
if has_children:
|
||||
# Early return on modules with no parameters.
|
||||
if len(list(module.parameters(recurse=False))) == 0:
|
||||
return
|
||||
|
||||
if ophook_list is not None:
|
||||
|
@@ -31,11 +31,11 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.col_attr.sharded_data_tensor.payload
|
||||
|
||||
@@ -44,20 +44,20 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.remove_torch_payload()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
tensor_list = []
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
|
||||
param.data = param.col_attr.sharded_data_tensor.payload
|
||||
# Store local accumulated grad shard
|
||||
@@ -77,11 +77,11 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
tensor_list = []
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.sharded_data_tensor)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
for param in module.parameters():
|
||||
for param in module.parameters(recurse=False):
|
||||
param.col_attr.remove_torch_payload()
|
||||
|
||||
def pre_iter(self):
|
||||
|
Reference in New Issue
Block a user