[zero] add ZeroTensorShardStrategy (#793)
Repository: https://github.com/hpcaitech/ColossalAI.git
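The change is test-suite plumbing: each ZeRO test below either switches its shard strategy to the new ZeroTensorShardStrategy or appends it to the list of parameterized strategies, alongside the existing TensorShardStrategy and BucketTensorShardStrategy.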
@@ -9,7 +9,7 @@ from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.zero.shard_utils import ZeroTensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_optim import ShardedOptimizerV2
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
@@ -20,7 +20,7 @@ from common import CONFIG


 @parameterize("cpu_offload", [True, False])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [ZeroTensorShardStrategy])
 @parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
 def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio):
     test_models = ['repeated_computed_layers']
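For readers unfamiliar with the harness: colossalai.testing.parameterize reruns the decorated function once per value, so stacking the decorators above executes _run_test_found_inf over the full cross-product of settings. A minimal sketch of that semantics (an approximation, not ColossalAI's actual implementation):

from functools import wraps

def parameterize(arg_name, values):
    # Rerun the wrapped function once per value, forwarding arguments
    # already bound by any outer parameterize decorators.
    def decorator(fn):
        @wraps(fn)
        def wrapper(**kwargs):
            for value in values:
                fn(**kwargs, **{arg_name: value})
        return wrapper
    return decorator

@parameterize("cpu_offload", [True, False])
@parameterize("gpu_margin_mem_ratio", [0.0, 0.7])
def demo(cpu_offload, gpu_margin_mem_ratio):
    print(cpu_offload, gpu_margin_mem_ratio)

demo()  # runs 2 x 2 = 4 combinations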
@@ -15,14 +15,14 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
     colo_model_mem_usage
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
 from tests.components_to_test.registry import non_distributed_component_funcs

 from common import CONFIG


 @parameterize("init_device_type", ['cpu', 'cuda'])
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
 def run_model_test(init_device_type, shard_strategy_class):
     logger = get_dist_logger("test_zero_init")

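In isolation, the pattern run_model_test drives looks like the sketch below: modules built inside ZeroInitContext have their parameters sharded by the chosen strategy at creation time. This assumes an already-launched ColossalAI distributed context (e.g. via colossalai.launch), and the nn.Linear stand-in model is illustrative only:

import torch.nn as nn
from colossalai.utils.cuda import get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import ZeroTensorShardStrategy

shard_strategy = ZeroTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    # Parameters are sharded across the data-parallel ranks as they
    # are created, so no rank materializes the full model.
    model = nn.Linear(1024, 1024)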
@@ -8,7 +8,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
 from colossalai.zero.init_ctx import ZeroInitContext
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.zero.shard_utils import ZeroTensorShardStrategy
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from functools import partial
@@ -35,7 +35,7 @@ def run_mem_collector_testing():
     fraction = (50 * 1024**2) / cuda_capacity
     # limit max memory to 50MB
     colo_set_process_memory_fraction(fraction)
-    shard_strategy = BucketTensorShardStrategy()
+    shard_strategy = ZeroTensorShardStrategy()
     with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
         model = MyTestModel()

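The fraction above is simply the 50 MB cap divided by total device memory. A worked example, assuming a hypothetical 16 GiB card (the test reads the real capacity via the imported colo_device_memory_capacity):

cuda_capacity = 16 * 1024**3               # bytes on an assumed 16 GiB GPU
fraction = (50 * 1024**2) / cuda_capacity  # 50 MB / 16 GiB
print(f"{fraction:.6f}")                   # 0.003052, i.e. ~0.3% of the device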
@@ -10,7 +10,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, ZeroTensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
@@ -21,7 +21,7 @@ from common import CONFIG, check_grads_padding, run_fwd_bwd


 @parameterize("enable_autocast", [True])
-@parameterize("shard_strategy_class", [BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [ZeroTensorShardStrategy, BucketTensorShardStrategy])
 def run_model_test(enable_autocast, shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'no_leaf_module']
     shard_strategy = shard_strategy_class()
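Every strategy in these lists implements the same contract: shard a parameter so each rank keeps only a 1/world_size slice, and gather the slices back into the full tensor when needed. A single-process toy illustration of that round trip (the real strategies operate collectively over a torch.distributed process group):

import torch

def shard(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Flatten, pad to a multiple of world_size, and keep this rank's slice.
    flat = t.flatten()
    pad = (-flat.numel()) % world_size
    flat = torch.cat([flat, flat.new_zeros(pad)])
    return flat.chunk(world_size)[rank].clone()

world_size = 4
param = torch.randn(10, 3)
shards = [shard(param, r, world_size) for r in range(world_size)]
# "Gather": concatenate the slices, drop the padding, restore the shape.
restored = torch.cat(shards)[:param.numel()].view_as(param)
assert torch.equal(restored, param)
print([s.numel() for s in shards])  # [8, 8, 8, 8] for 30 elements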
@@ -11,7 +11,7 @@ import torch.multiprocessing as mp
 from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.zero.init_ctx import ZeroInitContext
-from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
+from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy, ZeroTensorShardStrategy)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model.utils import col_model_deepcopy
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -19,7 +19,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG


-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
+@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy, ZeroTensorShardStrategy])
 def run_zero_state_dict(shard_strategy_class):
     test_models = ['repeated_computed_layers', 'resnet18']
     shard_strategy = shard_strategy_class()
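End to end, the state-dict round trip that run_zero_state_dict checks looks roughly like the sketch below. It assumes a launched distributed context; build_model() is a hypothetical stand-in for the registry's model factories, and the call signatures follow the imports above rather than verbatim test code:

from colossalai.utils.cuda import get_current_device
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import ZeroTensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

shard_strategy = ZeroTensorShardStrategy()
with ZeroInitContext(target_device=get_current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = build_model()            # hypothetical model factory
zero_model = ShardedModelV2(model, shard_strategy)
state = zero_model.state_dict()      # gathers shards into full tensors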