Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 09:59:38 +00:00
[zero] stateful tensor manager (#687)
* [WIP] stateful tensor manager
* add eviction strategy
* polish code
* polish code
* polish comment
* add unit test
* fix sampler bug
* polish code
* fix max sampling cnt resetting bug
* fix sampler bug
* polish code
* fix bug
* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
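This commit threads a StatefulTensorMgr through ZeroHook: when a manager is supplied, the hook marks each submodule's sharded tensors as TensorState.COMPUTE and asks the manager to adjust_layout() before every forward and backward op, instead of eagerly moving every tensor to the computing device. Before the diff, a minimal sketch of the eviction idea the commit message names. Everything here is invented for illustration and is not the ColossalAI implementation: the class name ToyTensorMgr, the fixed byte budget, and the id()-based compute set are all assumptions.

import torch


class ToyTensorMgr:
    """Hypothetical sketch of an eviction-based tensor manager.

    Keeps the tensors needed by the upcoming op on the GPU and evicts
    other managed tensors to the CPU when a fixed byte budget would be
    exceeded. Invented for illustration; not the ColossalAI class.
    """

    def __init__(self, budget_bytes: int):
        self._budget = budget_bytes
        self._tensors = []  # every tensor this manager is responsible for

    def register(self, t: torch.Tensor) -> None:
        self._tensors.append(t)

    @staticmethod
    def _nbytes(t: torch.Tensor) -> int:
        return t.numel() * t.element_size()

    def adjust_layout(self, compute_ids: set) -> None:
        # Bytes already resident on GPU, and bytes the upcoming op still needs.
        used = sum(self._nbytes(t) for t in self._tensors if t.is_cuda)
        needed = sum(self._nbytes(t) for t in self._tensors
                     if id(t) in compute_ids and not t.is_cuda)
        # Evict non-compute tensors until the compute set fits the budget.
        for t in self._tensors:
            if used + needed <= self._budget:
                break
            if t.is_cuda and id(t) not in compute_ids:
                used -= self._nbytes(t)
                t.data = t.data.cpu()
        # Materialize the compute set on the GPU.
        for t in self._tensors:
            if id(t) in compute_ids and not t.is_cuda:
                t.data = t.data.cuda()

    def reset(self) -> None:
        # Called once per iteration; a real manager would reset its
        # per-iteration bookkeeping (e.g. sampling counters) here.
        pass

The real StatefulTensorMgr is driven by the tensor states the hook sets (the trans_state(TensorState.COMPUTE) calls in the diff) rather than an explicit id set, and the "sampler" fixes in the commit message suggest its eviction decisions come from sampled memory statistics rather than a constant budget. The diff follows.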
colossalai/engine/ophooks/zero_hook.py
@@ -7,6 +7,7 @@ from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
 
 from ._base_ophook import BaseOpHook
 
@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):
 
     def __init__(self,
                  shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector],
+                 memstarts_collector: Optional[MemStatsCollector] = None,
+                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
         self.shard_strategy = shard_strategy
         self.process_group = process_group
 
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
 
         self._memstarts_collector = memstarts_collector
+        self._stateful_tensor_mgr = stateful_tensor_mgr
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
             param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
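The hunk above also relaxes the constructor: memstarts_collector gains a default, so a hook with neither collector nor manager keeps the old eager-move behavior. A hedged wiring sketch, not verbatim library usage: the ZeroHook signature follows the diff, the import paths and the register_ophooks_recursively helper are assumed from the ColossalAI layout of this era, and the StatefulTensorMgr constructor is elided because this diff does not show it.

import torch
from colossalai.engine.ophooks import register_ophooks_recursively
from colossalai.engine.ophooks.zero_hook import ZeroHook
from colossalai.zero.shard_utils import TensorShardStrategy

model = torch.nn.Linear(4, 4)           # stand-in for a ZeRO-sharded model
shard_strategy = TensorShardStrategy()

# stateful_tensor_mgr=None keeps the old behavior: every sharded tensor is
# moved to the computing device eagerly in pre_fwd_exec / pre_bwd_exec.
hook = ZeroHook(shard_strategy,
                memstarts_collector=None,   # optional now (previously required)
                stateful_tensor_mgr=None,   # or a constructed StatefulTensorMgr
                process_group=None)
register_ophooks_recursively(model, [hook])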
@@ -60,19 +71,27 @@ class ZeroHook(BaseOpHook):
             param.colo_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
             param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ class ZeroHook(BaseOpHook):
         pass
 
     def post_iter(self):
-        pass
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.reset()
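Net effect of the four hunks: a per-iteration lifecycle in which the manager is consulted before every hooked op and reset exactly once at the end of the iteration. The sketch below is a reading aid with stand-in callables (run_iteration, hooked_ops are invented names), not library code.

def run_iteration(hooked_ops, stateful_tensor_mgr):
    for pre_op_exec in hooked_ops:      # pre_fwd_exec, then pre_bwd_exec
        # 1. mark the op's sharded tensors TensorState.COMPUTE
        # 2. let the manager place stateful tensors (or move them eagerly)
        stateful_tensor_mgr.adjust_layout()
        # 3. gather shards and point param.data at the CUDA payload
        pre_op_exec()
    stateful_tensor_mgr.reset()         # post_iter: clear per-iteration state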