[refactor] refactor the memory utils (#715)

This commit is contained in:
Jiarui Fang
2022-04-11 16:47:57 +08:00
committed by GitHub
parent dbd96fe90a
commit 193dc8dacb
20 changed files with 218 additions and 308 deletions

View File

@@ -62,10 +62,9 @@ def _run_test_sharded_optim_v2(cpu_offload,
get_components_func = non_distributed_component_funcs.get_callable('no_leaf_module')
_, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(
target_device=torch.device('cpu') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
shard_strategy=shard_strategy,
shard_param=True):
with ZeroInitContext(target_device=torch.device('cpu') if cpu_offload else get_current_device(),
shard_strategy=shard_strategy,
shard_param=True):
zero_model = MoeModel()
zero_model = ShardedModelV2(zero_model,