polish code

This commit is contained in:
ver217
2022-03-18 13:57:20 +08:00
parent e99af94ab8
commit 8cf7ff08cf
4 changed files with 24 additions and 28 deletions

View File

@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam
from colossalai.testing import parameterize
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy,
TensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from colossalai.zero.sharded_optim import ShardedOptimizerV2
@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
@parameterize("cpu_offload", [True, False])
@parameterize("use_cpuadam", [True, False])
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = shard_strategy()
shard_strategy = shard_strategy_class()
if use_cpuadam and cpu_offload is False:
return