diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index eeb8117a2..b0ab82a94 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -7,6 +7,7 @@ from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
 
 from ._base_ophook import BaseOpHook
 
@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):
 
     def __init__(self,
                  shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector],
+                 memstarts_collector: Optional[MemStatsCollector] = None,
+                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
         self.shard_strategy = shard_strategy
         self.process_group = process_group
+        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
         self._memstarts_collector = memstarts_collector
+        self._stateful_tensor_mgr = stateful_tensor_mgr
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', "PRE FWD param.data must be on CUDA"
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
@@ -60,19 +71,27 @@ class ZeroHook(BaseOpHook):
             param.colo_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
+
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', "PRE BWD param.data must be on CUDA"
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ class ZeroHook(BaseOpHook):
         pass
 
     def post_iter(self):
-        pass
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.reset()
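The hook change above reorders the pre-op flow: parameters are marked COMPUTE first, then either the `StatefulTensorMgr` places all stateful tensors at once or the old per-tensor move is used as a fallback, and only afterwards is the payload bound to `param.data`. A minimal, self-contained sketch of that flow (the `ToyStatefulTensor`/`ToyManager` classes are illustrative stand-ins, not the ColossalAI API):

```python
# Toy sketch of the new pre-op control flow in ZeroHook (not the real API).
from enum import Enum, auto
from typing import List, Optional


class TensorState(Enum):
    HOLD = auto()
    COMPUTE = auto()


class ToyStatefulTensor:

    def __init__(self, name: str) -> None:
        self.name = name
        self.state = TensorState.HOLD
        self.device = 'cpu'

    def trans_state(self, state: TensorState) -> None:
        self.state = state


class ToyManager:

    def adjust_layout(self, tensors: List[ToyStatefulTensor]) -> None:
        # Decide placement globally: COMPUTE tensors must end up on CUDA.
        for t in tensors:
            if t.state == TensorState.COMPUTE:
                t.device = 'cuda'


def pre_op(tensors: List[ToyStatefulTensor], mgr: Optional[ToyManager]) -> None:
    # 1) State transition happens before layout adjustment, so the manager
    #    can see which tensors are about to be computed.
    for t in tensors:
        t.trans_state(TensorState.COMPUTE)
    # 2) Layout adjustment, or the eager per-tensor move as a fallback.
    if mgr is not None:
        mgr.adjust_layout(tensors)
    else:
        for t in tensors:
            t.device = 'cuda'
    # 3) Payloads bound to param.data are now guaranteed to be on CUDA.
    assert all(t.device == 'cuda' for t in tensors if t.state == TensorState.COMPUTE)


pre_op([ToyStatefulTensor('p0'), ToyStatefulTensor('p1')], ToyManager())
```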
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index 054d3f282..72c41b470 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -20,10 +20,12 @@ class SamplingCounter:
         assert self._max_sampling_cnt is not None
         return (self._samplint_cnt + 1) % self._max_sampling_cnt
 
-    @property
-    def sampling_cnt(self):
+    def current(self):
         return self._samplint_cnt
 
+    def max(self):
+        return self._max_sampling_cnt
+
     def reset(self):
         self._max_sampling_cnt = self._samplint_cnt
         self._samplint_cnt = 0
@@ -37,7 +39,7 @@ class MemStatsCollector:
     The first iteration of DNN training.
     Phase 2. Runtime Phase: use the read-only collected stats
     The rest iterations of DNN training.
-
+    It has a Sampling counter which is reset after DNN training iteration.
     """
 
@@ -50,6 +52,8 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []
 
+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []
         self._sampling_time = []
 
         self._start_flag = False
@@ -96,18 +100,20 @@ class MemStatsCollector:
             raise TypeError
 
         if device_type == 'cuda':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
+            return [elem / scale for elem in self._non_model_data_cuda_list]
         elif device_type == 'cpu':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
+            return [elem / scale for elem in self._non_model_data_cpu_list]
         else:
             raise TypeError
 
     def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of current sampling moment
+        """get the non model data of the current sampling moment
         """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
 
     def next_non_model_data(self, device_type: str):
+        """get the non model data of the next sampling moment
+        """
         return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
 
     @property
@@ -128,18 +134,20 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.sampling_cnt
+            sampling_cnt = self._sampling_cnter.current()
             assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
             self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
-
-            # FIXME() cpu sys used should also return from self._mem_monitor()
+            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
             self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
-
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
+        # TODO(ver217): refactor sampler
+        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
         self._sampling_cnter.advance()
 
     def reset_sampling_cnter(self) -> None:
@@ -155,4 +163,4 @@ class MemStatsCollector:
 
         self._start_flag = False
         self._sampling_cnter.reset()
-        self._mem_monitor.finish()
\ No newline at end of file
+        self._mem_monitor.finish()
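The collector now caches `non_model_data = overall - model_data` at every sampling moment instead of recomputing the difference on each query, and `current()`/`next()` index into those cached lists with wrap-around. A toy walkthrough with made-up numbers:

```python
# Sketch of the bookkeeping this change caches (toy numbers, not real measurements).
overall_cuda = [6.0, 7.5, 7.0]       # GB, peak overall usage per sampling moment
model_data_cuda = [4.0, 4.5, 4.5]    # GB, tracked model data per sampling moment

non_model_data_cuda = [o - m for o, m in zip(overall_cuda, model_data_cuda)]
assert non_model_data_cuda == [2.0, 3.0, 2.5]

# current/next lookups then become simple indexed reads; the counter wraps
# around with modulo arithmetic so "next" is well-defined at the last moment.
cnt, max_cnt = 2, 3
current = non_model_data_cuda[cnt]
nxt = non_model_data_cuda[(cnt + 1) % max_cnt]
assert (current, nxt) == (2.5, 2.0)
```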
""" if self._start_flag: - sampling_cnt = self._sampling_cnter.sampling_cnt + sampling_cnt = self._sampling_cnter.current() assert sampling_cnt == len(self._overall_cuda_list) self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage) self._overall_cuda_list.append(self._mem_monitor.finish()) + self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1]) self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage) - - # FIXME() cpu sys used should also return from self._mem_monitor() + # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor() self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu'))) - + self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1]) self._sampling_time.append(time.time()) self._mem_monitor.start() + # TODO(ver217): refactor sampler + # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}') self._sampling_cnter.advance() def reset_sampling_cnter(self) -> None: @@ -155,4 +163,4 @@ class MemStatsCollector: self._start_flag = False self._sampling_cnter.reset() - self._mem_monitor.finish() \ No newline at end of file + self._mem_monitor.finish() diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py index 5e5d63a7e..9a7917c63 100644 --- a/colossalai/zero/shard_utils/__init__.py +++ b/colossalai/zero/shard_utils/__init__.py @@ -1,5 +1,6 @@ from .base_shard_strategy import BaseShardStrategy from .bucket_tensor_shard_strategy import BucketTensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy +from .stateful_tensor_mgr import StatefulTensorMgr -__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy'] +__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr'] diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py index 8daefeb2f..3a14f5139 100644 --- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py +++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py @@ -1,26 +1,43 @@ +import functools import torch -from colossalai.context.singleton_meta import SingletonMeta +import types from colossalai.utils.cuda import get_current_device from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity -from typing import Set +from typing import Dict, List from colossalai.utils.memory_tracer import MemStatsCollector +from colossalai.logging import get_dist_logger -class StatefulTensorMgr(SingletonMeta): - _stateful_tensor_list: Set[ShardedParamV2] = set() +class StatefulTensorMgr(object): + """ + Stateful Tensor Manager, inspired from PatrickStar - def register_param(self, param: ShardedParamV2) -> None: + PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management + https://arxiv.org/abs/2108.05818 + """ + + def __init__(self, mem_stats_collector: MemStatsCollector) -> None: + self._stateful_tensor_list: List[StatefulTensor] = [] + self._mem_stats_collector = mem_stats_collector + self._logger = get_dist_logger("StatefulTensorMgr") + + self._warmup = True + self._warmup_cuda_available_ratio = 0.2 + + 
diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py
index 8daefeb2f..3a14f5139 100644
--- a/colossalai/zero/shard_utils/stateful_tensor_mgr.py
+++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py
@@ -1,26 +1,43 @@
+import functools
 import torch
-from colossalai.context.singleton_meta import SingletonMeta
+import types
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
-from typing import Set
+from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.logging import get_dist_logger
 
 
-class StatefulTensorMgr(SingletonMeta):
-    _stateful_tensor_list: Set[ShardedParamV2] = set()
+class StatefulTensorMgr(object):
+    """
+    Stateful Tensor Manager, inspired from PatrickStar
+
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+    https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
+        self._stateful_tensor_list: List[StatefulTensor] = []
+        self._mem_stats_collector = mem_stats_collector
+        self._logger = get_dist_logger("StatefulTensorMgr")
+
+        self._warmup = True
+        self._warmup_cuda_available_ratio = 0.2
+
+        self._compute_list: List[StatefulTensor] = []
+        self._compute_idx: int = -1
 
-    def register_param(self, param: ShardedParamV2) -> None:
+    def register_stateful_param(self, param: ShardedParamV2) -> None:
+        assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
             assert isinstance(t, StatefulTensor)
-            self._stateful_tensor_list.add(t)
+            self._stateful_tensor_list.append(t)
+            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
 
-    def evict_tensors(self) -> None:
-        pass
-
-    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
+    def adjust_layout(self) -> None:
         """
         Adjust the layout of statefuil tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
@@ -41,29 +58,62 @@ class StatefulTensorMgr(SingletonMeta):
                 used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
-            else:
+            elif tensor.device.type == 'cpu':
                 if tensor.state == TensorState.COMPUTE:
                     move_to_cuda_tensor_list.append(tensor)
-                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]
-
-        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
-                                                 mem_stats_collector.next_non_model_data('cuda'))
+                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
+            else:
+                raise RuntimeError
         cuda_capacity = colo_cuda_memory_capacity()
-        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
-        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
-            # move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor
-            # Here use a naive eviction strategy.
-            acc_size = 0
-            for t in hold_cuda_tensor_list:
-                if acc_size > cuda_demand:
-                    break
-                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
-                t_size = colo_tensor_mem_usage(t)
-                acc_size += t_size
-            if acc_size < cuda_demand:
-                raise RuntimeError("Adjust layout failed! No enough CUDA memory!")
+
+        if self._warmup:
+            # We designate a part of CUDA memory for model data in warmup iterations.
+            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
+        else:
+            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
+            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
+                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+
+        if avail_cuda_model_data < cuda_demand:
+            # Move cuda_demand - avail_cuda_model_data volume of tensors
+            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
 
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())
+
+    def reset(self):
+        """This function must be called when each iteration finishes
+        """
+        self._warmup = False
+        self._compute_idx = -1
+
+    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
+        freed_cuda_model_data = 0
+        to_free_tensor_list = hold_cuda_tensor_list
+        if not self._warmup:
+            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
+            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
+                if self._compute_list[i] in next_compute_idx:
+                    next_compute_idx[self._compute_list[i]] = i
+            next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+            to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+        for t in to_free_tensor_list:
+            if freed_cuda_model_data > to_free_cuda_model_data:
+                break
+            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
+            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+        if freed_cuda_model_data < to_free_cuda_model_data:
+            raise RuntimeError(
+                f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
+            )
+
+    def _trans_state(self, trans_state_func, stateful_tensor, state):
+        trans_state_func(state)
+        if state == TensorState.COMPUTE:
+            self._compute_idx += 1
+            if self._warmup:
+                self._compute_list.append(stateful_tensor)
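`evict_tensors` implements a Belady/OPT-style policy after warmup: among evictable HOLD tensors, free the one whose next appearance in the recorded compute order is furthest away. A standalone sketch of just that ordering (the `eviction_order` helper is hypothetical; strings stand in for tensors):

```python
# Sketch of the OPT-like eviction order used by evict_tensors after warmup.
from typing import Dict, Hashable, List


def eviction_order(hold_tensors: List[Hashable], compute_list: List[Hashable],
                   compute_idx: int) -> List[Hashable]:
    # Default: "never used again", encoded as len(compute_list).
    next_use: Dict[Hashable, int] = {t: len(compute_list) for t in hold_tensors}
    # Scan the tail of the compute order backwards so the earliest future
    # occurrence of each tensor wins.
    for i in range(len(compute_list) - 1, compute_idx, -1):
        if compute_list[i] in next_use:
            next_use[compute_list[i]] = i
    # Furthest next use comes first in the eviction order.
    return [t for t, _ in sorted(next_use.items(), key=lambda p: p[1], reverse=True)]


# Compute order recorded during warmup: p0 p1 p2 p0 p1; we are at index 2 (p2).
order = eviction_order(['p0', 'p1'], ['p0', 'p1', 'p2', 'p0', 'p1'], compute_idx=2)
assert order == ['p1', 'p0']    # p1 is reused later than p0, so evict p1 first
```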
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 32779ad89..9f05eb363 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -23,6 +23,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -36,7 +37,6 @@ class ShardedModelV2(nn.Module):
 
     Note:
         You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.
-
     Note:
         Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
         if you enable ``reuse_fp16_shard``.
@@ -106,12 +106,21 @@ class ShardedModelV2(nn.Module):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
+            self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
+            # for param in module.parameters():
+            for submodule in module.modules():
+                for param in submodule.parameters(recurse=False):
+                    if hasattr(param, 'colo_attr'):
+                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
         else:
             self._memstats_collector = None
+            self._stateful_tensor_mgr = None
         self._iter_cnter = 0
 
         # Register hooks
-        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        self._ophook_list = [
+            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
+        ]
         register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
         self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -138,6 +147,9 @@ class ShardedModelV2(nn.Module):
             self._cuda_margin_space = 0
         self.reuse_fp16_shard = reuse_fp16_shard
 
+    def adjust_stateful_tensor_layout(self) -> None:
+        self._stateful_tensor_mgr.adjust_layout()
+
     @property
     def use_memory_tracer(self):
         return self._use_memory_tracer
@@ -150,20 +162,15 @@ class ShardedModelV2(nn.Module):
     def cpu_offload(self):
         return self._cpu_offload
 
-    def dump_memory_stats(self, filename: str = 'dump_mem_stats.log') -> None:
-        """Dump memory tracer collected infomation to a file.
-
-        Example::
-
-            try:
-                # forward: model(inputs)
-                # backward: optimizer.backward()
-            except Exception as e:
-                model.dump_memory_stats()
-                exit(0)
-
-        Args:
-            filename (str, optional): Output file name. Defaults to 'dump_mem_stats.log'.
+    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
+        """
+        Dump memory tracer collected information to a file.
+        try:
+            # forward: model(inputs)
+            # backward: optimizer.backward()
+        except Exception as e:
+            model.dump_memory_stats()
+            exit(0)
         """
         if self._use_memory_tracer:
             self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
@@ -172,12 +179,12 @@ class ShardedModelV2(nn.Module):
                 f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
                 f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
                 f.write('CUDA model data (GB)\n')
-                f.write(str(self._memstats_collector.model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
                 f.write('\n')
                 f.write('CUDA non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
                 f.write('CPU non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cpu', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
                 f.write('\n')
 
     def _pre_forward_operations(self):
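For the registration walk added to `ShardedModelV2.__init__`, note that iterating `module.modules()` with `parameters(recurse=False)` visits each parameter exactly once, even for nested modules, so no `colo_attr` is registered with the manager twice. A quick check:

```python
# Sketch of the registration walk: each parameter is visited exactly once.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))
seen = []
for submodule in model.modules():
    for param in submodule.parameters(recurse=False):
        seen.append(id(param))
assert len(seen) == len(set(seen)) == len(list(model.parameters()))
```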
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index b9252eb94..31f58b9e0 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -350,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
 
                 if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
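The one-line optimizer change above writes the updated fp16 payload back to whatever device the sharded tensor currently occupies, rather than forcing CUDA, so CPU-offloaded shards stay put and placement remains the `StatefulTensorMgr`'s decision. A minimal sketch with a toy `clone_to` helper (not the ColossalAI API):

```python
# Device-preserving write-back: clone the updated fp16 copy onto the device
# the sharded payload already lives on, instead of torch.cuda.current_device().
import torch


def clone_to(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
    return tensor.clone().detach().to(device)


payload_device = torch.device('cpu')          # e.g. an offloaded shard
master_param = torch.randn(4, dtype=torch.float32)
new_payload = clone_to(master_param.half(), payload_device)
assert new_payload.device == payload_device and new_payload.dtype == torch.float16
```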
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 277eab380..92f5bb59c 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -26,7 +26,7 @@ class ShardedParamV2(object):
 
     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
-        return [self._sharded_data_tensor, self.saved_grad]
+        return [self._sharded_data_tensor]
 
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
""" - return [self._sharded_data_tensor, self.saved_grad] + return [self._sharded_data_tensor] def remove_torch_payload(self): self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device) diff --git a/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py new file mode 100644 index 000000000..b77d02e94 --- /dev/null +++ b/tests/test_zero_data_parallel/test_stateful_tensor_mgr.py @@ -0,0 +1,112 @@ +import torch +import colossalai +import pytest +import torch.multiprocessing as mp +from colossalai.utils.memory_tracer import MemStatsCollector +from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER +from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction +from colossalai.zero.shard_utils import StatefulTensorMgr +from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from colossalai.zero.sharded_param.tensorful_state import TensorState +from colossalai.utils import free_port +from colossalai.testing import rerun_on_exception +from torch.nn.parameter import Parameter +from typing import List +from functools import partial + + +class Net(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + # each parameter is 512 MB + self.p0 = Parameter(torch.empty(1024, 1024, 128)) + self.p1 = Parameter(torch.empty(1024, 1024, 128)) + self.p2 = Parameter(torch.empty(1024, 1024, 128)) + + +def run_stm(): + cuda_capacity = colo_cuda_memory_capacity() + fraction = (1.4 * 1024**3) / cuda_capacity + # limit max memory to 1.4GB + # which means only 2 parameters can be on CUDA + colo_set_process_memory_fraction(fraction) + model = Net() + for p in model.parameters(): + p.colo_attr = ShardedParamV2(p, rm_torch_payload=True) + GLOBAL_MODEL_DATA_TRACER.register_model(model) + mem_collector = MemStatsCollector() + stateful_tensor_mgr = StatefulTensorMgr(mem_collector) + for p in model.parameters(): + stateful_tensor_mgr.register_stateful_param(p.colo_attr) + + mem_collector.start_collection() + # Compute order: 0 1 2 0 1 + # warmup + # use naive eviction strategy + apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + mem_collector.finish_collection() + mem_collector.reset_sampling_cnter() + stateful_tensor_mgr.reset() + + # warmup done + # use OPT-like eviction strategy + apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr) + mem_collector.sample_memstats() + + +def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter], + stateful_tensor_mgr: 
StatefulTensorMgr): + compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE) + for p in model.parameters(): + if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD: + p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD) + stateful_tensor_mgr.adjust_layout() + print_stats(model) + device = torch.device(torch.cuda.current_device()) + cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust] + for n, p in model.named_parameters(): + if hash(p) in cuda_param_after_adjust: + assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}' + else: + assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu') + + +def print_stats(model: torch.nn.Module): + msgs = [] + for n, p in model.named_parameters(): + msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})') + print(f'[ {", ".join(msgs)} ]') + + +def run_dist(rank, world_size, port): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + run_stm() + + +@pytest.mark.dist +@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*") +def test_stateful_tensor_manager(world_size=1): + run_func = partial(run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_stateful_tensor_manager()
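A quick sanity check of the test's sizing: each parameter is 1024 × 1024 × 128 fp32 elements (512 MiB), so the 1.4 GB cap set by `colo_set_process_memory_fraction` admits exactly two parameters on CUDA at a time:

```python
# Why the 1.4 GB fraction in run_stm leaves room for exactly two parameters.
param_bytes = 1024 * 1024 * 128 * 4          # 512 MiB per fp32 parameter
budget = 1.4 * 1024**3                       # process CUDA memory cap

assert 2 * param_bytes <= budget < 3 * param_bytes
```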