mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[zero] refactor ShardedParamV2 for convenience (#742)
This commit is contained in:
@@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
- param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
+ param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
@@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
- param.colo_attr.remove_torch_payload()
|
||||
+ param.colo_attr.set_data_none()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
|
||||
@@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
- param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
+ param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
@@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
- param.colo_attr.remove_torch_payload()
|
||||
+ param.colo_attr.set_data_none()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
Reference in New Issue
Block a user