[zero] refactor ShardedParamV2 for convenience (#742)

This commit is contained in:
HELSON
2022-04-13 14:54:26 +08:00
committed by GitHub
parent 340e59f968
commit 22c4b88d56
16 changed files with 98 additions and 61 deletions

View File

@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
-            param.data = param.colo_attr.sharded_data_tensor.payload
+            param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
-            param.colo_attr.remove_torch_payload()
+            param.colo_attr.set_data_none()
def pre_iter(self):
pass