[polish] use GLOBAL_MODEL_DATA_TRACER (#417)

This commit is contained in:
Jiarui Fang
2022-03-15 11:29:46 +08:00
committed by GitHub
parent 23ba3fc450
commit 56bb412e72
8 changed files with 25 additions and 25 deletions

View File

@@ -5,7 +5,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
from ._base_ophook import BaseOpHook
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Optional
@@ -25,7 +25,6 @@ class ZeroHook(BaseOpHook):
def pre_fwd_exec(self, module: torch.nn.Module, *args):
tensor_list = []
global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
@@ -33,7 +32,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
if self._memstarts_collector:
@@ -50,7 +49,6 @@ class ZeroHook(BaseOpHook):
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
tensor_list = []
global_model_data_tracer = ModelDataTracer()
for param in module.parameters():
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.data)
@@ -58,7 +56,7 @@ class ZeroHook(BaseOpHook):
for param in module.parameters():
if param.col_attr.data.device != self.computing_device:
param.col_attr.data.to(self.computing_device)
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
param.data = param.col_attr.data.payload
# Store local accumulated grad shard
if param.grad is not None: