[zero] stateful tensor manager (#687)

* [WIP] stateful tensor manager

* add eviction strategy

* polish code

* polish code

* polish comment

* add unit test

* fix sampler bug

* polish code

* fix max sampling cnt resetting bug

* fix sampler bug

* polish code

* fix bug

* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
Author: ver217
Date: 2022-04-08 17:51:34 +08:00 (committed by GitHub)
Parent: 70e8dd418b
Commit: 3c9cd5bb5e
8 changed files with 271 additions and 73 deletions

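At a glance, the new StatefulTensorMgr is built from a MemStatsCollector, has every sharded parameter registered with it, places tensors via adjust_layout() inside the ZeroHook, and is reset() at the end of each iteration. The following condensed sketch of that wiring is assembled from the diffs below (it assumes the colossalai modules touched by this commit are importable and a CUDA device is present; it mirrors the unit test added at the end, with a plain Linear as a stand-in model):

# Condensed usage sketch; not part of the commit itself.
import torch
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.zero.shard_utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

model = torch.nn.Linear(1024, 1024)
for p in model.parameters():
    # wrap each parameter so its payload becomes a stateful (HOLD/COMPUTE/...) tensor
    p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)

mem_collector = MemStatsCollector()
stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
for p in model.parameters():
    stateful_tensor_mgr.register_stateful_param(p.colo_attr)

# Inside ZeroHook.pre_fwd_exec / pre_bwd_exec: the tensors an op needs are marked
# COMPUTE and the manager moves them to CUDA, evicting HOLD tensors if necessary.
stateful_tensor_mgr.adjust_layout()

# Inside ZeroHook.post_iter: leave the warmup phase and rewind the compute index.
stateful_tensor_mgr.reset()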
View File

@@ -7,6 +7,7 @@ from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr
 from ._base_ophook import BaseOpHook
@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):
     def __init__(self,
                  shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector],
+                 memstarts_collector: Optional[MemStatsCollector] = None,
+                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
         self.shard_strategy = shard_strategy
         self.process_group = process_group
         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')
         self._memstarts_collector = memstarts_collector
+        self._stateful_tensor_mgr = stateful_tensor_mgr

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
@@ -60,19 +71,27 @@ class ZeroHook(BaseOpHook):
             param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ class ZeroHook(BaseOpHook):
         pass

     def post_iter(self):
-        pass
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.reset()

View File

@@ -20,10 +20,12 @@ class SamplingCounter:
         assert self._max_sampling_cnt is not None
         return (self._samplint_cnt + 1) % self._max_sampling_cnt

-    @property
-    def sampling_cnt(self):
+    def current(self):
         return self._samplint_cnt

+    def max(self):
+        return self._max_sampling_cnt
+
     def reset(self):
         self._max_sampling_cnt = self._samplint_cnt
         self._samplint_cnt = 0
@@ -50,6 +52,8 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []
+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []

         self._sampling_time = []

         self._start_flag = False
@@ -96,18 +100,20 @@ class MemStatsCollector:
             raise TypeError

         if device_type == 'cuda':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
+            return [elem / scale for elem in self._non_model_data_cuda_list]
         elif device_type == 'cpu':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
+            return [elem / scale for elem in self._non_model_data_cpu_list]
         else:
             raise TypeError

     def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of current sampling moment
+        """get the non model data of the current sampling moment
         """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]

     def next_non_model_data(self, device_type: str):
+        """get the non model data of the next sampling moment
+        """
         return self.non_model_data_list(device_type)[self._sampling_cnter.next()]

     @property
@@ -128,18 +134,20 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.sampling_cnt
+            sampling_cnt = self._sampling_cnter.current()
             assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
+            self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])

             self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
-            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
+            # FIXME() cpu sys used should also return from self._mem_monitor()
             self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
+        # TODO(ver217): refactor sampler
+        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
         self._sampling_cnter.advance()

     def reset_sampling_cnter(self) -> None:

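The collector now records non-model data (overall usage minus model data) per sampling moment, and the counter exposes current()/max() so the stateful tensor manager can also look one moment ahead via next(), which wraps around at the end of an iteration. A standalone toy illustration of that lookup, using made-up numbers rather than colossalai code:

# Toy numbers only; mirrors non_model_data_list / current_non_model_data /
# next_non_model_data and the wrap-around of SamplingCounter.next().
overall_cuda = [6.0, 7.5, 5.0]        # GB measured at each sampling moment
model_data_cuda = [4.0, 4.0, 2.0]     # GB of model data at each moment
non_model_data_cuda = [o - m for o, m in zip(overall_cuda, model_data_cuda)]

max_cnt = len(non_model_data_cuda)    # fixed when the counter is reset after warmup
cnt = 2                               # current sampling moment (the last one)
current = non_model_data_cuda[cnt]                # 3.0
nxt = non_model_data_cuda[(cnt + 1) % max_cnt]    # wraps back to moment 0 -> 2.0
print(current, nxt)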
View File

@@ -1,5 +1,6 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
+from .stateful_tensor_mgr import StatefulTensorMgr

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']

View File

@@ -1,26 +1,43 @@
+import functools
 import torch
-from colossalai.context.singleton_meta import SingletonMeta
+import types
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
-from typing import Set
+from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.logging import get_dist_logger


-class StatefulTensorMgr(SingletonMeta):
-    _stateful_tensor_list: Set[ShardedParamV2] = set()
-
-    def register_param(self, param: ShardedParamV2) -> None:
+class StatefulTensorMgr(object):
+    """
+    Stateful Tensor Manager, inspired from PatrickStar
+
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+    https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
+        self._stateful_tensor_list: List[StatefulTensor] = []
+        self._mem_stats_collector = mem_stats_collector
+        self._logger = get_dist_logger("StatefulTensorMgr")
+
+        self._warmup = True
+        self._warmup_cuda_available_ratio = 0.2
+
+        self._compute_list: List[StatefulTensor] = []
+        self._compute_idx: int = -1
+
+    def register_stateful_param(self, param: ShardedParamV2) -> None:
+        assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
             assert isinstance(t, StatefulTensor)
-            self._stateful_tensor_list.add(t)
+            self._stateful_tensor_list.append(t)
+            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

-    def evict_tensors(self) -> None:
-        pass
-
-    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
+    def adjust_layout(self) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
@@ -41,29 +58,62 @@ class StatefulTensorMgr(SingletonMeta):
                 used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
-            else:
+            elif tensor.device.type == 'cpu':
                 if tensor.state == TensorState.COMPUTE:
                     move_to_cuda_tensor_list.append(tensor)
-                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]
+                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
+            else:
+                raise RuntimeError

-        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
-                                                 mem_stats_collector.next_non_model_data('cuda'))
         cuda_capacity = colo_cuda_memory_capacity()
-        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
-        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
-            # move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor
-            # Here use a naive eviction strategy.
-            acc_size = 0
-            for t in hold_cuda_tensor_list:
-                if acc_size > cuda_demand:
-                    break
-                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
-                t_size = colo_tensor_mem_usage(t)
-                acc_size += t_size
-            if acc_size < cuda_demand:
-                raise RuntimeError("Adjust layout failed! No enough CUDA memory!")
+
+        if self._warmup:
+            # We designate a part of CUDA memory for model data in warmup iterations.
+            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
+        else:
+            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
+            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
+                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+        if avail_cuda_model_data < cuda_demand:
+            # Move cuda_demand - avail_cuda_model_data volume of tensors
+            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
+
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())

+    def reset(self):
+        """This function must be called when each iteration finishes
+        """
+        self._warmup = False
+        self._compute_idx = -1
+
+    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
+        freed_cuda_model_data = 0
+        to_free_tensor_list = hold_cuda_tensor_list
+        if not self._warmup:
+            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
+            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
+                if self._compute_list[i] in next_compute_idx:
+                    next_compute_idx[self._compute_list[i]] = i
+            next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+            to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+        for t in to_free_tensor_list:
+            if freed_cuda_model_data > to_free_cuda_model_data:
+                break
+            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
+            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+        if freed_cuda_model_data < to_free_cuda_model_data:
+            raise RuntimeError(
+                f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
+            )
+
+    def _trans_state(self, trans_state_func, stateful_tensor, state):
+        trans_state_func(state)
+        if state == TensorState.COMPUTE:
+            self._compute_idx += 1
+            if self._warmup:
+                self._compute_list.append(stateful_tensor)

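After the warmup iteration, evict_tensors no longer walks the HOLD list in arbitrary order: it sorts candidates by the index of their next use in the compute order recorded during warmup and evicts the ones needed farthest in the future first (an OPT/Belady-style policy). A standalone toy illustration of that ordering, not colossalai code:

# Toy illustration of the eviction order computed in evict_tensors() above.
# compute_list is the per-iteration compute order recorded during warmup.
compute_list = ['p0', 'p1', 'p2', 'p0', 'p1']
compute_idx = 0                      # we are currently computing compute_list[0] (p0)
hold_on_cuda = ['p1', 'p2']          # HOLD tensors currently resident on CUDA

# find each candidate's next use after compute_idx (default: never used again)
next_compute_idx = {t: len(compute_list) for t in hold_on_cuda}
for i in range(len(compute_list) - 1, compute_idx, -1):
    if compute_list[i] in next_compute_idx:
        next_compute_idx[compute_list[i]] = i

# evict the tensor whose next use is farthest away first
eviction_order = [t for t, _ in sorted(next_compute_idx.items(), key=lambda kv: kv[1], reverse=True)]
print(eviction_order)  # ['p2', 'p1'] -> p2 (next needed at step 2) goes to CPU before p1 (step 1)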
View File

@@ -23,6 +23,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -36,7 +37,6 @@ class ShardedModelV2(nn.Module):
     Note:
         You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.
-
     Note:
         Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
         if you enable ``reuse_fp16_shard``.
@@ -106,12 +106,21 @@ class ShardedModelV2(nn.Module):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
+            self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
+            # for param in module.parameters():
+            for submodule in module.modules():
+                for param in submodule.parameters(recurse=False):
+                    if hasattr(param, 'colo_attr'):
+                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
         else:
             self._memstats_collector = None
+            self._stateful_tensor_mgr = None
         self._iter_cnter = 0

         # Register hooks
-        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        self._ophook_list = [
+            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
+        ]
         register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
         self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -138,6 +147,9 @@ class ShardedModelV2(nn.Module):
         self._cuda_margin_space = 0
         self.reuse_fp16_shard = reuse_fp16_shard

+    def adjust_stateful_tensor_layout(self) -> None:
+        self._stateful_tensor_mgr.adjust_layout()
+
     @property
     def use_memory_tracer(self):
         return self._use_memory_tracer
@@ -150,20 +162,15 @@ class ShardedModelV2(nn.Module):
     def cpu_offload(self):
         return self._cpu_offload

-    def dump_memory_stats(self, filename: str = 'dump_mem_stats.log') -> None:
-        """Dummy memory tracer collected infomation to a file.
-
-        Example::
-
-            try:
-                # forward: model(inputs)
-                # backward: optimizer.backward()
-            except Exception as e:
-                model.dump_memory_stats()
-                exit(0)
-
-        Args:
-            filename (str, optional): Output file name. Defaults to 'dump_mem_stats.log'.
+    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
+        """
+        dummy memory tracer collected infomation to a file.
+        try:
+            # forward: model(inputs)
+            # backward: optimizer.backward()
+        except Exception as e:
+            model.dump_memory_stats()
+            exit(0)
         """
         if self._use_memory_tracer:
             self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
@@ -172,12 +179,12 @@ class ShardedModelV2(nn.Module):
                 f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
                 f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
                 f.write('CUDA model data (GB)\n')
-                f.write(str(self._memstats_collector.model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
                 f.write('\n')
                 f.write('CUDA non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
                 f.write('CPU non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cpu', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
                 f.write('\n')

     def _pre_forward_operations(self):

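The docstring kept in this hunk describes the intended use of dump_memory_stats: wrap the training step and dump the collected statistics when it fails. Spelled out as a pattern sketch rather than runnable code (model, inputs and optimizer are assumed to come from an existing ShardedModelV2 / ShardedOptimizerV2 setup):

# Pattern from the dump_memory_stats docstring above; names are placeholders.
try:
    out = model(inputs)              # forward
    optimizer.backward(out.sum())    # backward
except Exception:
    model.dump_memory_stats()        # writes to 'dump_mem_stats.log' by default
    raise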
View File

@@ -350,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))

                 if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here

View File

@@ -26,7 +26,7 @@ class ShardedParamV2(object):
     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
-        return [self._sharded_data_tensor, self.saved_grad]
+        return [self._sharded_data_tensor]

     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

View File

@@ -0,0 +1,112 @@
+import torch
+import colossalai
+import pytest
+import torch.multiprocessing as mp
+from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
+from colossalai.zero.shard_utils import StatefulTensorMgr
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
+from torch.nn.parameter import Parameter
+from typing import List
+from functools import partial
+
+
+class Net(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        # each parameter is 512 MB
+        self.p0 = Parameter(torch.empty(1024, 1024, 128))
+        self.p1 = Parameter(torch.empty(1024, 1024, 128))
+        self.p2 = Parameter(torch.empty(1024, 1024, 128))
+
+
+def run_stm():
+    cuda_capacity = colo_cuda_memory_capacity()
+    fraction = (1.4 * 1024**3) / cuda_capacity
+    # limit max memory to 1.4GB
+    # which means only 2 parameters can be on CUDA
+    colo_set_process_memory_fraction(fraction)
+    model = Net()
+    for p in model.parameters():
+        p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
+    GLOBAL_MODEL_DATA_TRACER.register_model(model)
+    mem_collector = MemStatsCollector()
+    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
+    for p in model.parameters():
+        stateful_tensor_mgr.register_stateful_param(p.colo_attr)
+
+    mem_collector.start_collection()
+    # Compute order: 0 1 2 0 1
+    # warmup
+    # use naive eviction strategy
+    apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    mem_collector.finish_collection()
+    mem_collector.reset_sampling_cnter()
+    stateful_tensor_mgr.reset()
+
+    # warmup done
+    # use OPT-like eviction strategy
+    apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+
+
+def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
+                 stateful_tensor_mgr: StatefulTensorMgr):
+    compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
+    for p in model.parameters():
+        if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
+            p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
+    stateful_tensor_mgr.adjust_layout()
+    print_stats(model)
+    device = torch.device(torch.cuda.current_device())
+    cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
+    for n, p in model.named_parameters():
+        if hash(p) in cuda_param_after_adjust:
+            assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
+        else:
+            assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')
+
+
+def print_stats(model: torch.nn.Module):
+    msgs = []
+    for n, p in model.named_parameters():
+        msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
+    print(f'[ {", ".join(msgs)} ]')
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_stm()
+
+
+@pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_stateful_tensor_manager(world_size=1):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_stateful_tensor_manager()