mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-16 06:30:41 +00:00
[gemini] polish stateful_tensor_mgr (#876)
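This commit reworks how StatefulTensorMgr is driven by ShardedModelV2 and ZeroHook: parameter tensors are now registered in a single register_stateful_tensor_list call instead of a per-parameter loop, start_iter() is called in the pre-forward step, and the post-iteration reset() is replaced by finish_iter().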
@@ -111,10 +111,10 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector = None
         self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
             tensor_placement_policy)(mem_stats_collector=self._memstats_collector)

         self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
-        for param in module.parameters():
-            if hasattr(param, 'colo_attr'):
-                self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
+        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
+        self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)

         # Register hooks
         self._ophook_list = [
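A minimal, self-contained sketch of the new registration pattern. StatefulTensorMgrSketch below is a toy stand-in, not ColossalAI's class; only the method name register_stateful_tensor_list and the colo_attr.sharded_data_tensor access come from the hunk above, everything else is illustrative glue.

# Toy sketch (not ColossalAI code): the manager now receives the whole tensor
# list at once instead of one colo_attr per parameter.
from typing import Any, List

class StatefulTensorMgrSketch:
    """Stand-in for StatefulTensorMgr, reduced to the registration API."""

    def __init__(self) -> None:
        self._stateful_tensors: List[Any] = []

    def register_stateful_tensor_list(self, tensor_list: List[Any]) -> None:
        # One bulk call replaces the old per-parameter register_stateful_param() loop.
        self._stateful_tensors.extend(tensor_list)

def register_module_params(mgr: StatefulTensorMgrSketch, module) -> None:
    # Mirrors the new constructor logic: gather every sharded data tensor
    # guarded by hasattr(p, 'colo_attr'), then register them in one call.
    param_tensor_list = [p.colo_attr.sharded_data_tensor
                         for p in module.parameters() if hasattr(p, 'colo_attr')]
    mgr.register_stateful_tensor_list(param_tensor_list)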
@@ -198,6 +198,8 @@ class ShardedModelV2(nn.Module):
             if hasattr(p, 'colo_attr'):
                 p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+        self._stateful_tensor_mgr.start_iter()

     def _post_forward_operations(self):
         for p in self.module.parameters():
             if hasattr(p, 'colo_attr'):
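The hunk above calls start_iter() once every parameter tensor has been transitioned back to the HOLD state in the pre-forward path (the method just before _post_forward_operations). It pairs with the finish_iter() call that the next hunk adds to ZeroHook; see the sketch after that hunk.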
@@ -115,4 +115,4 @@ class ZeroHook(BaseOpHook):
         if self._stateful_tensor_mgr:
             self.logger.info(
                 f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
-            self._stateful_tensor_mgr.reset()
+            self._stateful_tensor_mgr.finish_iter()
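For orientation, a conceptual sketch of the per-iteration call order the last two hunks establish. Only start_iter(), finish_iter() and cpu_gpu_move_volume are taken from the diff; run_iteration, loss_fn, optimizer and the direct attribute access are hypothetical glue, and in the real code these calls are made inside ShardedModelV2 and ZeroHook rather than by user code.

def run_iteration(model, batch, loss_fn, optimizer):
    mgr = model._stateful_tensor_mgr      # built in ShardedModelV2.__init__ (first hunk)

    mgr.start_iter()                      # pre-forward: iteration bookkeeping starts here
    loss = loss_fn(model(batch))          # forward pass; op hooks drive tensor states
    model.backward(loss)                  # backward under ZeRO sharding (assumed helper)
    optimizer.step()

    # Post-iteration hook: report CPU<->GPU traffic, then close the iteration;
    # finish_iter() replaces the old reset() call.
    print(f"CPU-GPU data moving this iteration {mgr.cpu_gpu_move_volume / 1e9} GB")
    mgr.finish_iter()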