[zero] adapt zero hooks for unsharded module (#699)

This commit is contained in:
HELSON
2022-04-08 20:23:26 +08:00
committed by GitHub
parent 896ade15d6
commit ee112fe1da
12 changed files with 71 additions and 59 deletions

View File

@@ -36,6 +36,7 @@ class ZeroHook(BaseOpHook):
self._stateful_tensor_mgr = stateful_tensor_mgr
def pre_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
@@ -45,12 +46,15 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
@@ -59,18 +63,25 @@ class ZeroHook(BaseOpHook):
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
# change tensor state to HOLD_AFTER_FWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
@@ -80,12 +91,15 @@ class ZeroHook(BaseOpHook):
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# gather sharded parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
# record memory statistics
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
@@ -94,15 +108,20 @@ class ZeroHook(BaseOpHook):
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
# change tensor state to HOLD_AFTER_BWD
for param in module.parameters(recurse=False):
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# shard gathered parameters
if module.param_is_sharded:
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()