mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-09 04:50:17 +00:00
[refactor] refactor the memory utils (#715)
This commit is contained in:
@@ -62,10 +62,9 @@ def _run_test_sharded_optim_v2(cpu_offload,
|
||||
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
|
||||
_, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(
|
||||
target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True):
|
||||
zero_model = MoeModel()
|
||||
|
||||
zero_model = ShardedModelV2(zero_model,
|
||||
|
Reference in New Issue
Block a user