mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 01:06:00 +00:00
[zero] get memory usage of sharded optim v2. (#542)
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -57,11 +58,12 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ZeroInitContext(convert_fp16=True,
|
||||
target_device=torch.device(f'cpu:0'),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True,
|
||||
rm_torch_payload_on_the_fly=False):
|
||||
with ZeroInitContext(
|
||||
convert_fp16=True,
|
||||
target_device=torch.device(f'cpu:0') if cpu_offload else torch.device(f'cuda:{get_current_device()}'),
|
||||
shard_strategy=shard_strategy,
|
||||
shard_param=True,
|
||||
rm_torch_payload_on_the_fly=False):
|
||||
zero_model = model_builder(checkpoint=True)
|
||||
zero_model = ShardedModelV2(zero_model,
|
||||
shard_strategy,
|
||||
|
Reference in New Issue
Block a user