mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-12 14:48:31 +00:00
[hotfix] remove duplicated param register to stateful tensor manager (#728)
This commit is contained in:
parent
600e769a42
commit
7db3ccc79b
@ -106,9 +106,7 @@ class ShardedModelV2(nn.Module):
|
|||||||
GLOBAL_MODEL_DATA_TRACER.register_model(self)
|
GLOBAL_MODEL_DATA_TRACER.register_model(self)
|
||||||
self._memstats_collector = MemStatsCollector()
|
self._memstats_collector = MemStatsCollector()
|
||||||
self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
|
self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
|
||||||
# for param in module.parameters():
|
for param in module.parameters():
|
||||||
for submodule in module.modules():
|
|
||||||
for param in submodule.parameters(recurse=False):
|
|
||||||
if hasattr(param, 'colo_attr'):
|
if hasattr(param, 'colo_attr'):
|
||||||
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
|
||||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||||
|
Loading…
Reference in New Issue
Block a user