From e396bb71f2d557a44566fd7ec958475f5d406b8e Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 13 Apr 2022 15:00:48 +0800
Subject: [PATCH] [zero] add tensor placement policies (#743)

* add tensor placement policies

* polish comments

* polish comments

* update moe unit tests
---
 .../zero/sharded_model/sharded_model_v2.py   | 25 +++--
 .../zero/sharded_optim/sharded_optim_v2.py   | 18 ++--
 colossalai/zero/utils/stateful_tensor_mgr.py | 55 ++---------
 .../zero/utils/tensor_placement_policy.py    | 94 +++++++++++++++++++
 tests/test_moe/test_moe_zero_model.py        |  2 +-
 tests/test_moe/test_moe_zero_optim.py        |  4 +-
 tests/test_zero/common.py                    |  6 +-
 tests/test_zero/test_found_inf.py            |  8 +-
 tests/test_zero/test_shard_model_v2.py       |  2 +-
 tests/test_zero/test_sharded_optim_v2.py     |  4 +-
 tests/test_zero/test_stateful_tensor_mgr.py  |  4 +-
 11 files changed, 139 insertions(+), 83 deletions(-)
 create mode 100644 colossalai/zero/utils/tensor_placement_policy.py

diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 87564f9b8..04608184a 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -23,6 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -48,6 +49,11 @@ class ShardedModelV2(nn.Module):
             Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
         reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
         fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        tensor_placement_policy (str): Which device to place *held* tensors on. It can be 'cpu', 'cuda' or 'auto'.
+            If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means the minimum amount of CUDA memory will be used.
+            If it's 'cuda', they won't be offloaded, which means the maximum amount of CUDA memory will be used.
+            If it's 'auto', they are moved dynamically based on CPU and CUDA memory usage, so that heterogeneous memory space is utilized efficiently.
+            Defaults to 'cuda'.
         offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
         gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
         use_memory_tracer (bool, optional): Whether to use memory tracer. Defaults to False.
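The sketch below illustrates how the new `tensor_placement_policy` argument replaces the old `offload_config`/`use_memory_tracer` pair; it is an editorial example, not part of the patch. The constructor keyword arguments for `ShardedModelV2` and `ShardedOptimizerV2` are the ones introduced by this commit, while `MyModel`, the `target_device` value, and the import paths for `ZeroInitContext`, `TensorShardStrategy` and `HybridAdam` are assumptions inferred from the test code further down.

```python
# Illustrative sketch only -- not part of this patch. Import paths and MyModel are
# assumptions inferred from the surrounding tests; adjust them to your setup.
import torch
from colossalai.nn.optimizer import HybridAdam              # assumed path (used in tests/test_zero/test_found_inf.py)
from colossalai.zero.init_ctx import ZeroInitContext        # assumed path
from colossalai.zero.shard_utils import TensorShardStrategy  # assumed path (used in tests/test_zero/common.py)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.device('cuda'), shard_strategy=shard_strategy, shard_param=True):
    model = MyModel()  # hypothetical module

# 'cpu'  -> params/grads/optimizer states always live on CPU (minimum CUDA memory)
# 'cuda' -> never offloaded (maximum CUDA memory)
# 'auto' -> moved dynamically based on memory usage; this also enables the memory-stats tracer
zero_model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='auto', reuse_fp16_shard=True)

optim = HybridAdam(zero_model.parameters(), lr=1e-3)
# gpu_margin_mem_ratio only has an effect with the 'auto' policy (see the warning added below).
sharded_optim = ShardedOptimizerV2(zero_model, optim, gpu_margin_mem_ratio=0.2, initial_scale=2**5)
```

With 'cpu' or 'cuda', the optimizer simply inherits the policy's device (replacing the old `cpu_offload` flag), as the `self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')` line below shows.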
@@ -65,9 +71,8 @@ class ShardedModelV2(nn.Module): reduce_scatter_process_group: Optional[ProcessGroup] = None, reduce_scatter_bucket_size_mb: int = 25, fp32_reduce_scatter: bool = False, - offload_config: Optional[dict] = None, + tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, - use_memory_tracer: bool = False, reuse_fp16_shard: bool = False): super().__init__() self.logger = get_dist_logger() @@ -100,20 +105,22 @@ class ShardedModelV2(nn.Module): self.rank = dist.get_rank(self.process_group) self.shard_strategy = shard_strategy + assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}' # Init Memory Statistics Collector - self._use_memory_tracer = use_memory_tracer + self._use_memory_tracer = tensor_placement_policy == 'auto' if self._use_memory_tracer: GLOBAL_MODEL_DATA_TRACER.register_model(self) self._memstats_collector = MemStatsCollector() - self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector) - for param in module.parameters(): - if hasattr(param, 'colo_attr'): - self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) self._start_collect_memstats = disposable(self._memstats_collector.start_collection) self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection) else: self._memstats_collector = None - self._stateful_tensor_mgr = None + self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy]( + mem_stats_collector=self._memstats_collector) + self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy) + for param in module.parameters(): + if hasattr(param, 'colo_attr'): + self._stateful_tensor_mgr.register_stateful_param(param.colo_attr) # Register hooks self._ophook_list = [ @@ -124,7 +131,7 @@ class ShardedModelV2(nn.Module): self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook) self.fp32_reduce_scatter = fp32_reduce_scatter - self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False + self._cpu_offload: bool = tensor_placement_policy != 'cuda' for param in module.parameters(): # Init `offload_grad` param.colo_attr.offload_grad = self._cpu_offload diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py index 087d2c1f5..acf3d5904 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py @@ -16,12 +16,12 @@ from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_m colo_tensor_mem_usage) from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState) from torch import Tensor from torch.distributed import ProcessGroup from torch.nn.parameter import Parameter from torch.optim import Optimizer +from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy class OptimState(Enum): @@ -57,10 +57,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer): sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the shard strategy provided by sharded model to shard param fp32 tensors. optimizer (Optimizer): An Optimizer instance. 
-        cpu_offload (bool, optional): Is offloading the optimizer states to CPU.. Defaults to False.
         gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
             which will be used when using hybrid CPU optimizer.
             Make sure `reuse_fp16_shard` is enabled in `ShardedModelV2`, if `gpu_margin_mem_ratio` > `0.0`.
+            This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
             Defaults to 0.0.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  sharded_model: ShardedModelV2,
                  optimizer: Optimizer,
-                 cpu_offload: bool = False,
                  gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -95,18 +94,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
-        if cpu_offload and not sharded_model.cpu_offload:
-            raise RuntimeError(
-                f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
-            )
+
         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
         assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
         # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
         # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
         # and it must set `num_fp32_shards_per_param` correctly
-        self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
+        self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
             optimizer, 'num_fp32_shards_per_param', 0) >= 2
-        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
@@ -123,7 +119,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

         # Store fp32 param shards
         self._register_master_weight()
-
+        if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
+                                                               AutoTensorPlacementPolicy):
+            self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])

diff --git a/colossalai/zero/utils/stateful_tensor_mgr.py b/colossalai/zero/utils/stateful_tensor_mgr.py
index 1674775a2..c06dcc4a3 100644
--- a/colossalai/zero/utils/stateful_tensor_mgr.py
+++ b/colossalai/zero/utils/stateful_tensor_mgr.py
@@ -5,10 +5,8 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Dict, List
-from 
colossalai.utils.memory_tracer import MemStatsCollector +from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy +from typing import List from colossalai.logging import get_dist_logger @@ -20,13 +18,12 @@ class StatefulTensorMgr(object): https://arxiv.org/abs/2108.05818 """ - def __init__(self, mem_stats_collector: MemStatsCollector) -> None: + def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None: + self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy self._stateful_tensor_list: List[StatefulTensor] = [] - self._mem_stats_collector = mem_stats_collector self._logger = get_dist_logger("StatefulTensorMgr") self._warmup = True - self._warmup_cuda_available_ratio = 0.2 self._compute_list: List[StatefulTensor] = [] self._compute_idx: int = -1 @@ -47,9 +44,8 @@ class StatefulTensorMgr(object): It contains non-model footprint of a DNN model. """ # find stateful tensor in state COMPUTE - move_to_cuda_tensor_list = [] cuda_demand = 0 - used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage + move_to_cuda_tensor_list = [] hold_cuda_tensor_list = [] for tensor in self._stateful_tensor_list: if tensor.state == TensorState.FREE: @@ -64,22 +60,11 @@ class StatefulTensorMgr(object): cuda_demand += colo_tensor_mem_usage(tensor.payload)[1] else: raise RuntimeError - cuda_capacity = colo_device_memory_capacity(get_current_device()) - - if self._warmup: - # We designate a part of CUDA memory for model data in warmup iterations. - max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio - else: - # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. - max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda') - - total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period - avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data - - if avail_cuda_model_data < cuda_demand: - # Move cuda_demand - avail_cuda_model_data volume of tensors - # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data - self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data) + self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list, + cuda_demand=cuda_demand, + warmup=self._warmup, + compute_list=self._compute_list, + compute_idx=self._compute_idx) # move COMPUTE tensors to CUDA for t in move_to_cuda_tensor_list: colo_model_data_tensor_move_inline(t, get_current_device()) @@ -90,26 +75,6 @@ class StatefulTensorMgr(object): self._warmup = False self._compute_idx = -1 - def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data): - freed_cuda_model_data = 0 - to_free_tensor_list = hold_cuda_tensor_list - if not self._warmup: - next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list} - for i in range(len(self._compute_list) - 1, self._compute_idx, -1): - if self._compute_list[i] in next_compute_idx: - next_compute_idx[self._compute_list[i]] = i - next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True) - to_free_tensor_list = [t for (t, idx) in next_compute_idx] - for t in to_free_tensor_list: - if freed_cuda_model_data > to_free_cuda_model_data: - break - freed_cuda_model_data += colo_tensor_mem_usage(t)[0] - colo_model_data_tensor_move_inline(t, torch.device('cpu')) - if freed_cuda_model_data < to_free_cuda_model_data: - raise RuntimeError( - f"Adjust layout failed! No enough CUDA memory! 
Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}" - ) - def _trans_state(self, trans_state_func, stateful_tensor, state): trans_state_func(state) if state == TensorState.COMPUTE: diff --git a/colossalai/zero/utils/tensor_placement_policy.py b/colossalai/zero/utils/tensor_placement_policy.py new file mode 100644 index 000000000..953fd956c --- /dev/null +++ b/colossalai/zero/utils/tensor_placement_policy.py @@ -0,0 +1,94 @@ +from typing import List, Optional, Dict +import torch +from colossalai.utils import get_current_device +from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.sharded_param.tensorful_state import StatefulTensor +from colossalai.utils.memory_tracer import MemStatsCollector +from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER + +__all__ = ['TENSOR_PLACEMENT_POLICIES'] + + +class TensorPlacementPolicy: + + def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + self.device: Optional[torch.device] = device + self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None: + raise NotImplementedError + + +class CPUTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None: + for t in hold_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, self.device) + + +class CUDATensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available' + super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector) + + def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None: + pass + + +class AutoTensorPlacementPolicy(TensorPlacementPolicy): + + def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None: + super().__init__(None, mem_stats_collector=mem_stats_collector) + self._warmup_non_model_data_ratio: float = 0.2 + + def evict_tensors(self, + hold_cuda_tensor_list: List[StatefulTensor], + cuda_demand: int = 0, + warmup: bool = True, + compute_list: List[StatefulTensor] = [], + compute_idx: int = 0, + **kwargs) -> None: + cuda_capacity = colo_device_memory_capacity(get_current_device()) + used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage + if warmup: + # We designate a part of CUDA memory for model data in warmup iterations. + max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio + else: + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. 
+            max_cuda_non_model_data_per_period = self.mem_stats_collector.max_non_model_data('cuda')
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+        if avail_cuda_model_data < cuda_demand:
+            # Move cuda_demand - avail_cuda_model_data volume of tensors
+            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            freed_cuda_model_data = 0
+            to_free_tensor_list = hold_cuda_tensor_list
+            if not warmup:
+                next_compute_idx: Dict[StatefulTensor, int] = {t: len(compute_list) for t in hold_cuda_tensor_list}
+                for i in range(len(compute_list) - 1, compute_idx, -1):
+                    if compute_list[i] in next_compute_idx:
+                        next_compute_idx[compute_list[i]] = i
+                next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+                to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+            for t in to_free_tensor_list:
+                if freed_cuda_model_data > to_free_cuda_model_data:
+                    break
+                freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
+                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+            if freed_cuda_model_data < to_free_cuda_model_data:
+                raise RuntimeError(
+                    f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
+                )
+
+
+TENSOR_PLACEMENT_POLICIES = {
+    'cpu': CPUTensorPlacementPolicy,
+    'cuda': CUDATensorPlacementPolicy,
+    'auto': AutoTensorPlacementPolicy
+}
diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py
index 2e3b620cf..945f8ba3c 100644
--- a/tests/test_moe/test_moe_zero_model.py
+++ b/tests/test_moe/test_moe_zero_model.py
@@ -32,7 +32,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
                          shard_strategy=shard_strategy,
                          shard_param=True):
         zero_model = MoeModel(checkpoint=True)
-    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+    zero_model = ShardedModelV2(zero_model, shard_strategy)

     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py
index cb39b8d7b..b0562da7c 100644
--- a/tests/test_moe/test_moe_zero_optim.py
+++ b/tests/test_moe/test_moe_zero_optim.py
@@ -69,8 +69,7 @@ def _run_test_sharded_optim_v2(cpu_offload,

     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
-                                offload_config=dict(device='cpu') if cpu_offload else None,
-                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                tensor_placement_policy='cpu' if cpu_offload else 'cuda',
                                 reuse_fp16_shard=reuse_fp16_shard)

     # check whether parameters are identical in ddp
@@ -88,7 +87,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        sharded_optim,
-                                       cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio)

diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py
index d495cf018..5d2ff173f 100644
--- a/tests/test_zero/common.py
+++ b/tests/test_zero/common.py
@@ -13,14 +13,12 @@ MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(siz

 _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           fp32_reduce_scatter=False,
-                          offload_config=None,
+                          tensor_placement_policy='cuda',
                           gradient_predivide_factor=1.0,
-                          use_memory_tracer=False,
                           shard_strategy=TensorShardStrategy(),
                           reuse_fp16_shard=False)

-_ZERO_OPTIMIZER_CONFIG = 
dict(cpu_offload=False, - initial_scale=2**5, +_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, min_scale=1, growth_factor=2, backoff_factor=0.5, diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py index 45bdd6e01..897038355 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_found_inf.py @@ -37,16 +37,12 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio) zero_model = ShardedModelV2( zero_model, shard_strategy, - offload_config=dict(device='cpu') if cpu_offload else None, - use_memory_tracer=gpu_margin_mem_ratio > 0.0, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', reuse_fp16_shard=True, ) sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - cpu_offload=cpu_offload, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) for i, (data, label) in enumerate(train_dataloader): if i > 1: diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py index 2d230f85f..1f46883a8 100644 --- a/tests/test_zero/test_shard_model_v2.py +++ b/tests/test_zero/test_shard_model_v2.py @@ -33,7 +33,7 @@ def run_model_test(enable_autocast, shard_strategy_class): shard_strategy=shard_strategy, shard_param=True): zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True) + zero_model = ShardedModelV2(zero_model, shard_strategy) model = model_builder(checkpoint=True).half() col_model_deepcopy(zero_model, model) diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py index 34287969f..2e94df7de 100644 --- a/tests/test_zero/test_sharded_optim_v2.py +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -64,8 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g zero_model = ShardedModelV2( zero_model, shard_strategy, - offload_config=dict(device='cpu') if cpu_offload else None, - use_memory_tracer=gpu_margin_mem_ratio > 0.0, + tensor_placement_policy='cpu' if cpu_offload else 'cuda', reuse_fp16_shard=use_cpuadam, ) @@ -79,7 +78,6 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, - cpu_offload=cpu_offload, initial_scale=2**5, gpu_margin_mem_ratio=gpu_margin_mem_ratio) diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py index bc8475914..5b9e35a26 100644 --- a/tests/test_zero/test_stateful_tensor_mgr.py +++ b/tests/test_zero/test_stateful_tensor_mgr.py @@ -14,6 +14,7 @@ from colossalai.testing import rerun_on_exception from torch.nn.parameter import Parameter from typing import List from functools import partial +from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy class Net(torch.nn.Module): @@ -37,7 +38,8 @@ def run_stm(): p.colo_attr = ShardedParamV2(p, set_data_none=True) GLOBAL_MODEL_DATA_TRACER.register_model(model) mem_collector = MemStatsCollector() - stateful_tensor_mgr = StatefulTensorMgr(mem_collector) + tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector) + stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy) for p in model.parameters(): 
stateful_tensor_mgr.register_stateful_param(p.colo_attr)
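
A note on the eviction logic that moved from `StatefulTensorMgr.evict_tensors` into `AutoTensorPlacementPolicy.evict_tensors`: when the CUDA model-data budget would be exceeded, held tensors are moved to CPU starting with the one whose next use in the compute list lies furthest in the future (or that is never used again). The standalone sketch below restates only that ordering rule; the helper name `eviction_order` is hypothetical and not part of the patch.

```python
# Illustrative restatement of the eviction ordering used by AutoTensorPlacementPolicy
# (not part of the patch): tensors needed again the latest are evicted first.
from typing import Dict, Hashable, List


def eviction_order(hold_list: List[Hashable], compute_list: List[Hashable], compute_idx: int) -> List[Hashable]:
    """Sort held tensors so that the one needed furthest in the future is evicted first."""
    # Tensors that never appear again default to len(compute_list), i.e. the furthest possible index.
    next_use: Dict[Hashable, int] = {t: len(compute_list) for t in hold_list}
    # Walk backwards from the end of the compute list down to (but excluding) the current step,
    # so each held tensor ends up mapped to the index of its *next* use.
    for i in range(len(compute_list) - 1, compute_idx, -1):
        if compute_list[i] in next_use:
            next_use[compute_list[i]] = i
    return [t for t, _ in sorted(next_use.items(), key=lambda pair: pair[1], reverse=True)]


# 'c' is never needed again, 'b' is next needed at step 5, 'a' at step 4 -> evict in that order.
assert eviction_order(['a', 'b', 'c'], ['c', 'a', 'b', 'd', 'a', 'b'], compute_idx=2) == ['c', 'b', 'a']
```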