mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[tensor] refactor chunk mgr and impl MemStatsCollectorV2 (#1077)
* polish chunk manager * polish unit test * impl add_extern_static_tensor for chunk mgr * add mem stats collector v2 * polish code * polish unit test * polish code * polish get chunks
This commit is contained in:
@@ -20,12 +20,13 @@ class ZeROHookV2(ParamOpHook):
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
chunks = self._chunk_manager.get_chunks(params)
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._chunk_manager.exec_lazy_release()
|
||||
# TODO: evict chunks
|
||||
for p in params:
|
||||
self._chunk_manager.access_chunk(p)
|
||||
for chunk in chunks:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
def post_op(self, params):
|
||||
for p in params:
|
||||
|
||||
@@ -48,7 +48,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if not self.module.chunk_manager.is_chunk_free(p):
|
||||
if not self.module.chunk_manager.get_chunk(p).is_free:
|
||||
p.data = self.fp16_param_to_fp32_param[p]
|
||||
else:
|
||||
assert p.grad is None
|
||||
|
||||
Reference in New Issue
Block a user