[gemini] accelerate adjust_layout() (#878)
* add lru cache
* polish code
* update unit test
* fix sharded optim
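The "add lru cache" item is the heart of the speedup: adjust_layout() is invoked once per compute step, and within a training iteration the same working set of tensors recurs, so its placement decision can be memoized instead of recomputed. Below is a minimal sketch of the idea, assuming a pure helper keyed on hashable tensor ids; the helper name, the cuda_capacity parameter, and the keep-first placement rule are illustrative assumptions, not ColossalAI's actual code.

# Sketch: memoize a pure layout-decision helper so repeated adjust_layout()
# calls with the same tensor set skip recomputation. Hypothetical, not the
# real ColossalAI implementation.
import functools
from typing import Tuple

@functools.lru_cache(maxsize=None)
def _compute_layout(tensor_ids: Tuple[int, ...], cuda_capacity: int) -> Tuple[int, ...]:
    # Hypothetical placement rule: keep the first `cuda_capacity` tensors on CUDA.
    # Cache keys must be hashable, hence a tuple of ids rather than tensor objects.
    return tensor_ids[:cuda_capacity]

def adjust_layout(tensors, cuda_capacity):
    # Convert the unhashable tensor list into a hashable cache key, then
    # translate the cached decision back into concrete tensors.
    ids = tuple(id(t) for t in tensors)
    keep = set(_compute_layout(ids, cuda_capacity))
    return [t for t in tensors if id(t) in keep]

Repeated calls with an unchanged tensor set now hit the cache, which is exactly the pattern the test below exercises (compute order 0 1 2 0 1 revisits steps 0 and 1).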
@@ -45,8 +45,8 @@ def run_stm():
     mem_collector = MemStatsCollector()
     tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
     stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
-    for p in model.parameters():
-        stateful_tensor_mgr.register_stateful_param(p.colo_attr)
+    stateful_tensors = [p.colo_attr.sharded_data_tensor for p in model.parameters()]
+    stateful_tensor_mgr.register_stateful_tensor_list(stateful_tensors)

     mem_collector.start_collection()
     # Compute order: 0 1 2 0 1
@@ -67,7 +67,7 @@ def run_stm():
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
     mem_collector.sample_model_data()
     mem_collector.finish_collection()
-    stateful_tensor_mgr.reset()
+    stateful_tensor_mgr.finish_iter()

     # warmup done
     # only 2 params can be on CUDA
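The test changes mirror the new API surface: all stateful tensors are registered in one register_stateful_tensor_list() call rather than one register_stateful_param() call per parameter, and reset() gives way to finish_iter(), which marks the end of a training iteration. A plausible reading is that per-iteration state (such as cached layout decisions) is cleared there while registration survives across iterations. The class below is a hypothetical stand-in for StatefulTensorMgr that illustrates this pattern; none of its internals are taken from the real implementation.

# Hypothetical stand-in illustrating batch registration plus a per-iteration
# cache that finish_iter() clears. A sketch, not ColossalAI's StatefulTensorMgr.
class TensorMgrSketch:
    def __init__(self):
        self._tensors = []
        self._layout_cache = {}  # compute step index -> placement decision

    def register_stateful_tensor_list(self, tensors):
        # One call registers everything, avoiding per-parameter bookkeeping.
        self._tensors.extend(tensors)

    def adjust_layout(self, compute_idx):
        # Reuse the cached decision when the same compute step repeats
        # within one iteration (e.g. compute order 0 1 2 0 1).
        if compute_idx not in self._layout_cache:
            self._layout_cache[compute_idx] = [id(t) for t in self._tensors]
        return self._layout_cache[compute_idx]

    def finish_iter(self):
        # End of an iteration: drop per-iteration state, keep registration.
        self._layout_cache.clear()

mgr = TensorMgrSketch()
mgr.register_stateful_tensor_list(["t0", "t1", "t2"])
mgr.adjust_layout(0)   # computed and cached
mgr.adjust_layout(0)   # cache hit
mgr.finish_iter()      # cache cleared for the next iteration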