From 7db3ccc79b388d40fd93beb5908471e5910650ff Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Tue, 12 Apr 2022 13:55:25 +0800 Subject: [PATCH] [hotfix] remove duplicated param register to stateful tensor manager (#728) --- colossalai/zero/sharded_model/sharded_model_v2.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a324be8c5..028d0854c 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -106,11 +106,9 @@ class ShardedModelV2(nn.Module): GLOBAL_MODEL_DATA_TRACER.register_model(self) self._memstats_collector = MemStatsCollector() self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector) - # for param in module.parameters(): - for submodule in module.modules(): - for param in submodule.parameters(recurse=False): - if hasattr(param, 'colo_attr'): - self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) + for param in module.parameters(): + if hasattr(param, 'colo_attr'): + self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) else: