mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 18:09:06 +00:00
[zero] memtracer to record cuda memory usage of model data and overall system (#395)
This commit is contained in:
@@ -4,6 +4,9 @@ from colossalai.utils import get_current_device
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
|
||||
from ._base_ophook import BaseOpHook
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import ModelDataTracer
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
@@ -12,14 +15,17 @@ class ZeroHook(BaseOpHook):
|
||||
A hook to process sharded param for ZeRO method.
|
||||
"""
|
||||
|
||||
def __init__(self, shard_strategy: BaseShardStrategy):
|
||||
def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector]):
|
||||
super().__init__()
|
||||
self.shard_strategy = shard_strategy
|
||||
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
|
||||
self.computing_device = torch.device(f'cuda:{get_current_device()}')
|
||||
|
||||
self._memstarts_collector = memstarts_collector
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
global_model_data_tracer = ModelDataTracer()
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
@@ -27,8 +33,12 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
|
||||
param.data = param.col_attr.data.payload
|
||||
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
tensor_list = []
|
||||
for param in module.parameters():
|
||||
@@ -40,6 +50,7 @@ class ZeroHook(BaseOpHook):
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
tensor_list = []
|
||||
global_model_data_tracer = ModelDataTracer()
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'col_attr')
|
||||
tensor_list.append(param.col_attr.data)
|
||||
@@ -47,6 +58,7 @@ class ZeroHook(BaseOpHook):
|
||||
for param in module.parameters():
|
||||
if param.col_attr.data.device != self.computing_device:
|
||||
param.col_attr.data.to(self.computing_device)
|
||||
global_model_data_tracer.add_tensor(param.col_attr.data.payload)
|
||||
param.data = param.col_attr.data.payload
|
||||
# Store local accumulated grad shard
|
||||
if param.grad is not None:
|
||||
@@ -60,6 +72,8 @@ class ZeroHook(BaseOpHook):
|
||||
# The grad here must be locally computed full grad in this backward pass
|
||||
assert param.grad.shape == param.col_attr.data.origin_shape
|
||||
param.col_attr.bwd_count += 1
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
tensor_list = []
|
||||
|
Reference in New Issue
Block a user