[tensor] refactor chunk manager and implement MemStatsCollectorV2 (#1077)

* polish chunk manager

* polish unit test

* implement add_extern_static_tensor for chunk manager

* add memory stats collector v2 (MemStatsCollectorV2)

* polish code

* polish unit test

* polish code

* polish get_chunks
This commit is contained in:
ver217
2022-06-09 20:56:34 +08:00
committed by GitHub
parent b3a03e4bfd
commit be01db37c8
6 changed files with 68 additions and 31 deletions

View File

@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
self._training_phase = TrainingPhase.FORWARD
def pre_op(self, params):
chunks = self._chunk_manager.get_chunks(params)
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._chunk_manager.exec_lazy_release()
# TODO: evict chunks
for p in params:
self._chunk_manager.access_chunk(p)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
def post_op(self, params):
for p in params:

View File

@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
def _update_params_ptr(self):
for group in self.optim.param_groups:
for p in group['params']:
if not self.module.chunk_manager.is_chunk_free(p):
if not self.module.chunk_manager.get_chunk(p).is_free:
p.data = self.fp16_param_to_fp32_param[p]
else:
assert p.grad is None