mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-10-29 04:40:36 +00:00
polish code
This commit is contained in:
@@ -9,8 +9,7 @@ from colossalai.nn.optimizer import CPUAdam
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.zero.init_ctx import ZeroInitContext
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy,
|
||||
TensorShardStrategy)
|
||||
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model.utils import col_model_deepcopy
|
||||
from colossalai.zero.sharded_optim import ShardedOptimizerV2
|
||||
@@ -41,10 +40,10 @@ def _run_step(model, optimizer, data, label, criterion, enable_autocast=False):
|
||||
|
||||
@parameterize("cpu_offload", [True, False])
|
||||
@parameterize("use_cpuadam", [True, False])
|
||||
@parameterize("shard_strategy", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy, use_cpuadam):
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam):
|
||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||
shard_strategy = shard_strategy()
|
||||
shard_strategy = shard_strategy_class()
|
||||
|
||||
if use_cpuadam and cpu_offload is False:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user