[gemini] accelerate adjust_layout() (#878)

* add LRU cache

* polish code

* update unit test

* fix sharded optimizer
Author: ver217
Date: 2022-04-26 18:08:31 +08:00
Committed by: GitHub
Parent: 909211453b
Commit: c4d903e64a
5 changed files with 62 additions and 36 deletions


@@ -45,8 +45,8 @@ def run_stm():
     mem_collector = MemStatsCollector()
     tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
     stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
-    for p in model.parameters():
-        stateful_tensor_mgr.register_stateful_param(p.colo_attr)
+    stateful_tensors = [p.colo_attr.sharded_data_tensor for p in model.parameters()]
+    stateful_tensor_mgr.register_stateful_tensor_list(stateful_tensors)
     mem_collector.start_collection()
     # Compute order: 0 1 2 0 1
@@ -67,7 +67,7 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_model_data()
     mem_collector.finish_collection()
-    stateful_tensor_mgr.reset()
+    stateful_tensor_mgr.finish_iter()
     # warmup done
     # only 2 params can be on CUDA
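
For reference, the two hunks above track an API reshape: per-parameter registration via register_stateful_param(p.colo_attr) becomes one batched register_stateful_tensor_list() call over the sharded data tensors, and reset() becomes finish_iter() once the warmup iteration ends. A runnable stub of that lifecycle, mocking everything (none of the classes below are real ColossalAI code):

class FakeStatefulTensorMgr:
    """Mock illustrating the call sequence; not the real StatefulTensorMgr."""

    def __init__(self):
        self._tensors = []

    def register_stateful_tensor_list(self, tensors):
        # New batched entry point (replaces per-param register_stateful_param).
        self._tensors.extend(tensors)

    def adjust_layout(self):
        # Placement decisions happen here on every compute step.
        pass

    def finish_iter(self):
        # End-of-iteration hook (replaces the old reset()); after warmup the
        # collected access pattern can be reused by later iterations.
        pass

mgr = FakeStatefulTensorMgr()
mgr.register_stateful_tensor_list(["p1_data", "p2_data", "p3_data"])
for _ in [0, 1, 2, 0, 1]:  # the test's compute order
    mgr.adjust_layout()
mgr.finish_iter()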