diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/gemini/stateful_tensor_mgr.py index 15f121710..9ee1a6805 100644 --- a/colossalai/gemini/stateful_tensor_mgr.py +++ b/colossalai/gemini/stateful_tensor_mgr.py @@ -6,7 +6,6 @@ from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, c from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy from typing import List -from colossalai.logging import get_dist_logger class StatefulTensorMgr(object): @@ -20,23 +19,30 @@ class StatefulTensorMgr(object): def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None: self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy self._stateful_tensor_list: List[StatefulTensor] = [] - self._logger = get_dist_logger("StatefulTensorMgr") - - self._warmup = True self._compute_list: List[StatefulTensor] = [] self._compute_idx: int = -1 self._cpu_gpu_move_volume = 0 + self._warmup = True - def register_stateful_param(self, param) -> None: - from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 - assert isinstance(param, ShardedParamV2) - for t in param.get_payload_tensors(): + def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None: + assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice" + self._stateful_tensor_list = tensor_list + for t in self._stateful_tensor_list: assert isinstance(t, StatefulTensor) - self._stateful_tensor_list.append(t) t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t) + def start_iter(self): + pass + + def finish_iter(self): + """This function must be called when each iteration finishes + """ + self._warmup = False + self._compute_idx = -1 + self._cpu_gpu_move_volume = 0 + def adjust_layout(self) -> None: """ Adjust the layout of statefuil tensor according to the information provided by mem_stats_collector, which should belongs to a Sharded Model. @@ -63,21 +69,14 @@ class StatefulTensorMgr(object): compute_list=self._compute_list, compute_idx=self._compute_idx) # move COMPUTE tensors to CUDA + self._cpu_gpu_move_volume += cuda_demand for t in move_to_cuda_tensor_list: colo_model_data_tensor_move_inline(t, get_current_device()) - self._cpu_gpu_move_volume += t.payload_size @property def cpu_gpu_move_volume(self): return self._cpu_gpu_move_volume - def reset(self): - """This function must be called when each iteration finishes - """ - self._warmup = False - self._compute_idx = -1 - self._cpu_gpu_move_volume = 0 - def _trans_state(self, trans_state_func, stateful_tensor, state): trans_state_func(state) if state == TensorState.COMPUTE: diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 0f958aaea..cc37ddf17 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -111,10 +111,10 @@ class ShardedModelV2(nn.Module): self._memstats_collector = None self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create( tensor_placement_policy)(mem_stats_collector=self._memstats_collector) + self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) - for param in module.parameters(): - if hasattr(param, 'colo_attr'): - self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) + param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')] + self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list) # Register hooks self._ophook_list = [ @@ -198,6 +198,8 @@ class ShardedModelV2(nn.Module): if hasattr(p, 'colo_attr'): p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD) + self._stateful_tensor_mgr.start_iter() + def _post_forward_operations(self): for p in self.module.parameters(): if hasattr(p, 'colo_attr'): diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py index 5aa9da158..384617030 100644 --- a/colossalai/zero/utils/zero_hook.py +++ b/colossalai/zero/utils/zero_hook.py @@ -115,4 +115,4 @@ class ZeroHook(BaseOpHook): if self._stateful_tensor_mgr: self.logger.info( f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0]) - self._stateful_tensor_mgr.reset() + self._stateful_tensor_mgr.finish_iter()