[zero] add ZeroTensorShardStrategy (#793)

This commit is contained in:
HELSON
2022-04-19 14:32:45 +08:00
committed by GitHub
parent 681addb512
commit 88759e289e
12 changed files with 214 additions and 11 deletions

View File

@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
@@ -20,7 +20,7 @@ from common import CONFIG
@parameterize("cpu_offload", [True, False])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
test_models = ['repeated_computed_layers']

View File

@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
colo_model_mem_usage
from colossalai.utils.memory import colo_device_memory_used
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@parameterize("init_device_type", ['cpu', 'cuda'])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
def run_model_test(init_device_type, shard_strategy_class):
logger = get_dist_logger("test_zero_init")

View File

@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.shard_utils import BucketTensorShardStrategy
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from functools import partial
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
fraction = (50 * 1024**2) / cuda_capacity
# limit max memory to 50MB
colo_set_process_memory_fraction(fraction)
shard_strategy = BucketTensorShardStrategy()
shard_strategy = ZeroTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
model = MyTestModel()

View File

@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd
@parameterize("enable_autocast", [True])
@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
def run_model_test(enable_autocast, shard_strategy_class):
test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
shard_strategy = shard_strategy_class()

View File

@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model.utils import col_model_deepcopy
from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
from common import CONFIG
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
def run_zero_state_dict(shard_strategy_class):
test_models = ['repeated_computed_layers', 'resnet18']
shard_strategy = shard_strategy_class()