Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 10:06:44 +00:00)
[zero] stateful tensor manager (#687)

* [WIP] stateful tensor manager
* add eviction strategy
* polish code
* polish code
* polish comment
* add unit test
* fix sampler bug
* polish code
* fix max sampling cnt resetting bug
* fix sampler bug
* polish code
* fix bug
* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
ZeroHook op hook:

@@ -7,6 +7,7 @@ from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

 from ._base_ophook import BaseOpHook

@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):

     def __init__(self,
                  shard_strategy: BaseShardStrategy,
-                 memstarts_collector: Optional[MemStatsCollector],
+                 memstarts_collector: Optional[MemStatsCollector] = None,
+                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                  process_group: Optional[dist.ProcessGroup] = None):
         super().__init__()
         self.shard_strategy = shard_strategy
         self.process_group = process_group

         # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
         self.computing_device = torch.device(f'cuda:{get_current_device()}')

         self._memstarts_collector = memstarts_collector
+        self._stateful_tensor_mgr = stateful_tensor_mgr

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
@@ -60,19 +71,27 @@ class ZeroHook(BaseOpHook):
             param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.adjust_layout()
+        else:
+            for param in module.parameters(recurse=False):
+                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'colo_attr')
             tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
-        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.data = param.colo_attr.sharded_data_tensor.payload
+            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ class ZeroHook(BaseOpHook):
         pass

     def post_iter(self):
-        pass
+        if self._stateful_tensor_mgr:
+            self._stateful_tensor_mgr.reset()
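For readers who prefer the new control flow outside of diff form: the reworked pre-op path first marks every parameter's stateful tensor as COMPUTE, then either delegates placement to the stateful tensor manager or falls back to moving each tensor to the computing device, and only afterwards gathers shards and rebinds param.data. The toy sketch below mirrors that ordering with stand-in classes; none of these names are ColossalAI types.

from typing import List, Optional

class ToyTensor:
    def __init__(self, name: str) -> None:
        self.name = name
        self.state = 'HOLD'
        self.device = 'cpu'

class ToyManager:
    # stand-in for StatefulTensorMgr; the real adjust_layout() takes no arguments
    def adjust_layout(self, tensors: List[ToyTensor]) -> None:
        for t in tensors:
            if t.state == 'COMPUTE':
                t.device = 'cuda'

def pre_op(tensors: List[ToyTensor], mgr: Optional[ToyManager]) -> None:
    for t in tensors:
        t.state = 'COMPUTE'             # 1. state transition happens before any data movement
    if mgr is not None:
        mgr.adjust_layout(tensors)      # 2a. manager decides what stays on CUDA
    else:
        for t in tensors:
            t.device = 'cuda'           # 2b. legacy behaviour: move everything to the computing device
    # 3. gathering shards and 4. rebinding param.data are omitted in this toy

tensors = [ToyTensor('p0'), ToyTensor('p1')]
pre_op(tensors, ToyManager())
assert all(t.device == 'cuda' for t in tensors)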
colossalai/utils/memory_tracer/memstats_collector.py:

@@ -20,10 +20,12 @@ class SamplingCounter:
         assert self._max_sampling_cnt is not None
         return (self._samplint_cnt + 1) % self._max_sampling_cnt

-    @property
-    def sampling_cnt(self):
+    def current(self):
         return self._samplint_cnt

+    def max(self):
+        return self._max_sampling_cnt
+
     def reset(self):
         self._max_sampling_cnt = self._samplint_cnt
         self._samplint_cnt = 0
@@ -50,6 +52,8 @@ class MemStatsCollector:
         self._model_data_cpu_list = []
         self._overall_cpu_list = []

+        self._non_model_data_cuda_list = []
+        self._non_model_data_cpu_list = []
         self._sampling_time = []

         self._start_flag = False
@@ -96,18 +100,20 @@ class MemStatsCollector:
             raise TypeError

         if device_type == 'cuda':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
+            return [elem / scale for elem in self._non_model_data_cuda_list]
         elif device_type == 'cpu':
-            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
+            return [elem / scale for elem in self._non_model_data_cpu_list]
         else:
             raise TypeError

     def current_non_model_data(self, device_type: str) -> int:
-        """get the non model data of current sampling moment
+        """get the non model data of the current sampling moment
         """
-        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]

     def next_non_model_data(self, device_type: str):
+        """get the non model data of the next sampling moment
+        """
         return self.non_model_data_list(device_type)[self._sampling_cnter.next()]

     @property
@@ -128,18 +134,20 @@ class MemStatsCollector:
         Advance the sampling cnter.
         """
         if self._start_flag:
-            sampling_cnt = self._sampling_cnter.sampling_cnt
+            sampling_cnt = self._sampling_cnter.current()
             assert sampling_cnt == len(self._overall_cuda_list)
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
+            self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])

             self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
-            # FIXME() cpu sys used should also return from self._mem_monitor()
+            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
             self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
             self._sampling_time.append(time.time())
             self._mem_monitor.start()
+        # TODO(ver217): refactor sampler
+        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
         self._sampling_cnter.advance()

     def reset_sampling_cnter(self) -> None:
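The bookkeeping formalized by the new _non_model_data_* lists is simple: at each sampling moment, non-model data is overall device usage minus the model data the tracer attributes to tensors, and SamplingCounter.current()/next() index into those per-moment lists, with next() wrapping around at the end of an iteration. A minimal, self-contained sketch of that arithmetic; the names below are illustrative, not the ColossalAI API.

# Illustrative only: per-moment non-model data and a wrapping sampling counter.
overall_cuda = [6.0, 7.5, 9.0]       # GB measured at three sampling moments
model_data_cuda = [4.0, 4.0, 6.0]    # GB attributed to model tensors at the same moments
non_model_data_cuda = [o - m for o, m in zip(overall_cuda, model_data_cuda)]
assert non_model_data_cuda == [2.0, 3.5, 3.0]

class Counter:
    def __init__(self, max_cnt: int) -> None:
        self._cnt = 0
        self._max = max_cnt

    def current(self) -> int:
        return self._cnt

    def next(self) -> int:
        # wraps around, so the last moment of one iteration looks ahead to the first of the next
        return (self._cnt + 1) % self._max

    def advance(self) -> None:
        self._cnt = (self._cnt + 1) % self._max

cnter = Counter(max_cnt=len(non_model_data_cuda))
cnter.advance()
worst_case = max(non_model_data_cuda[cnter.current()], non_model_data_cuda[cnter.next()])
assert worst_case == 3.5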
colossalai/zero/shard_utils/__init__.py:

@@ -1,5 +1,6 @@
 from .base_shard_strategy import BaseShardStrategy
 from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
 from .tensor_shard_strategy import TensorShardStrategy
+from .stateful_tensor_mgr import StatefulTensorMgr

-__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']
colossalai/zero/shard_utils/stateful_tensor_mgr.py:

@@ -1,26 +1,43 @@
+import functools
 import torch
-from colossalai.context.singleton_meta import SingletonMeta
+import types
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
-from typing import Set
+from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.logging import get_dist_logger


-class StatefulTensorMgr(SingletonMeta):
-    _stateful_tensor_list: Set[ShardedParamV2] = set()
+class StatefulTensorMgr(object):
+    """
+    Stateful Tensor Manager, inspired from PatrickStar
+
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+    https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
+        self._stateful_tensor_list: List[StatefulTensor] = []
+        self._mem_stats_collector = mem_stats_collector
+        self._logger = get_dist_logger("StatefulTensorMgr")
+
+        self._warmup = True
+        self._warmup_cuda_available_ratio = 0.2
+
+        self._compute_list: List[StatefulTensor] = []
+        self._compute_idx: int = -1

-    def register_param(self, param: ShardedParamV2) -> None:
+    def register_stateful_param(self, param: ShardedParamV2) -> None:
+        assert isinstance(param, ShardedParamV2)
         for t in param.get_payload_tensors():
             assert isinstance(t, StatefulTensor)
-            self._stateful_tensor_list.add(t)
+            self._stateful_tensor_list.append(t)
+            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

-    def evict_tensors(self) -> None:
-        pass
-
-    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
+    def adjust_layout(self) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
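The types.MethodType(functools.partial(...)) line above is the least obvious part of the registration: it replaces each tensor's bound trans_state with a wrapper that calls the original method and then lets the manager observe the transition, so the manager can record the compute order without changing any call site. A small self-contained sketch of that wrapping pattern, using toy classes rather than the ColossalAI ones:

import functools
import types

class Tensor:
    def __init__(self) -> None:
        self.state = 'HOLD'

    def trans_state(self, state: str) -> None:
        self.state = state

class Manager:
    def __init__(self) -> None:
        self.compute_order = []

    def register(self, tensor: Tensor) -> None:
        # The partial captures the original bound trans_state; MethodType re-binds
        # the wrapper to the tensor instance (partial objects are not descriptors).
        tensor.trans_state = types.MethodType(functools.partial(self._trans_state, tensor.trans_state), tensor)

    def _trans_state(self, original_trans_state, tensor: Tensor, state: str) -> None:
        original_trans_state(state)            # keep the original behaviour
        if state == 'COMPUTE':
            self.compute_order.append(tensor)  # record the access order as a side effect

mgr = Manager()
t = Tensor()
mgr.register(t)
t.trans_state('COMPUTE')    # looks like a normal call, but the manager observes it
assert t.state == 'COMPUTE' and mgr.compute_order == [t]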
colossalai/zero/shard_utils/stateful_tensor_mgr.py (continued):

@@ -41,29 +58,62 @@ class StatefulTensorMgr(SingletonMeta):
                 used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
-            else:
+            elif tensor.device.type == 'cpu':
                 if tensor.state == TensorState.COMPUTE:
                     move_to_cuda_tensor_list.append(tensor)
-                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]
+                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
+            else:
+                raise RuntimeError

-        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
-                                                 mem_stats_collector.next_non_model_data('cuda'))
         cuda_capacity = colo_cuda_memory_capacity()
-        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
-        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
-            # move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor
-            # Here use a naive eviction strategy.
-            acc_size = 0
-            for t in hold_cuda_tensor_list:
-                if acc_size > cuda_demand:
-                    break
-                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
-                t_size = colo_tensor_mem_usage(t)
-                acc_size += t_size
-            if acc_size < cuda_demand:
-                raise RuntimeError("Adjust layout failed! No enough CUDA memory!")

+        if self._warmup:
+            # We designate a part of CUDA memory for model data in warmup iterations.
+            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
+        else:
+            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
+            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
+                                                     self._mem_stats_collector.next_non_model_data('cuda'))
+
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+
+        if avail_cuda_model_data < cuda_demand:
+            # Move cuda_demand - avail_cuda_model_data volume of tensors
+            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())

+    def reset(self):
+        """This function must be called when each iteration finishes
+        """
+        self._warmup = False
+        self._compute_idx = -1
+
+    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
+        freed_cuda_model_data = 0
+        to_free_tensor_list = hold_cuda_tensor_list
+        if not self._warmup:
+            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
+            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
+                if self._compute_list[i] in next_compute_idx:
+                    next_compute_idx[self._compute_list[i]] = i
+            next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+            to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+        for t in to_free_tensor_list:
+            if freed_cuda_model_data > to_free_cuda_model_data:
+                break
+            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
+            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+        if freed_cuda_model_data < to_free_cuda_model_data:
+            raise RuntimeError(
+                f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
+            )
+
+    def _trans_state(self, trans_state_func, stateful_tensor, state):
+        trans_state_func(state)
+        if state == TensorState.COMPUTE:
+            self._compute_idx += 1
+            if self._warmup:
+                self._compute_list.append(stateful_tensor)
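The eviction order in evict_tensors is the classic furthest-next-use (Belady/OPT) heuristic: during warmup the manager records the order in which tensors enter COMPUTE in _compute_list, and afterwards it looks up each HOLD tensor's next use after _compute_idx and evicts the tensors needed furthest in the future first. A minimal, self-contained sketch of that ordering logic under those assumptions, with plain strings standing in for stateful tensors:

from typing import Dict, List

def furthest_next_use_order(hold_list: List[str], compute_list: List[str], compute_idx: int) -> List[str]:
    # Tensors never used again get index len(compute_list), so they are evicted first.
    next_compute_idx: Dict[str, int] = {t: len(compute_list) for t in hold_list}
    # Scan the recorded compute order backwards, down to (but not including) the current step.
    for i in range(len(compute_list) - 1, compute_idx, -1):
        if compute_list[i] in next_compute_idx:
            next_compute_idx[compute_list[i]] = i
    ranked = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
    return [t for t, _ in ranked]

# Recorded order from a warmup iteration: p0 p1 p2 p0 p1; we are currently at step 2 (p2).
order = furthest_next_use_order(hold_list=['p0', 'p1'], compute_list=['p0', 'p1', 'p2', 'p0', 'p1'], compute_idx=2)
# p1 is needed again at step 4, p0 already at step 3, so p1 is the better eviction candidate.
assert order == ['p1', 'p0']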
ShardedModelV2:

@@ -23,6 +23,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
+from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -36,7 +37,6 @@ class ShardedModelV2(nn.Module):

     Note:
         You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.
-
     Note:
         Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
         if you enable ``reuse_fp16_shard``.
@@ -106,12 +106,21 @@ class ShardedModelV2(nn.Module):
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
+            self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
+            # for param in module.parameters():
+            for submodule in module.modules():
+                for param in submodule.parameters(recurse=False):
+                    if hasattr(param, 'colo_attr'):
+                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
         else:
             self._memstats_collector = None
+            self._stateful_tensor_mgr = None
         self._iter_cnter = 0

         # Register hooks
-        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        self._ophook_list = [
+            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
+        ]
         register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
         self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -138,6 +147,9 @@ class ShardedModelV2(nn.Module):
         self._cuda_margin_space = 0
         self.reuse_fp16_shard = reuse_fp16_shard

+    def adjust_stateful_tensor_layout(self) -> None:
+        self._stateful_tensor_mgr.adjust_layout()
+
     @property
     def use_memory_tracer(self):
         return self._use_memory_tracer
@@ -150,20 +162,15 @@ class ShardedModelV2(nn.Module):
     def cpu_offload(self):
         return self._cpu_offload

-    def dump_memory_stats(self, filename: str = 'dump_mem_stats.log') -> None:
-        """Dummy memory tracer collected infomation to a file.
-
-        Example::
-
-            try:
-                # forward: model(inputs)
-                # backward: optimizer.backward()
-            except Exception as e:
-                model.dump_memory_stats()
-                exit(0)
-
-        Args:
-            filename (str, optional): Output file name. Defaults to 'dump_mem_stats.log'.
+    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
+        """
+        dummy memory tracer collected infomation to a file.
+        try:
+            # forward: model(inputs)
+            # backward: optimizer.backward()
+        except Exception as e:
+            model.dump_memory_stats()
+            exit(0)
         """
         if self._use_memory_tracer:
             self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
@@ -172,12 +179,12 @@ class ShardedModelV2(nn.Module):
                 f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
                 f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
                 f.write('CUDA model data (GB)\n')
-                f.write(str(self._memstats_collector.model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
                 f.write('\n')
                 f.write('CUDA non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cuda', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
                 f.write('CPU non model data (GB)\n')
-                f.write(str(self._memstats_collector.non_model_data_cuda_list('cpu', 'GB')))
+                f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
                 f.write('\n')

     def _pre_forward_operations(self):
ShardedOptimizerV2:

@@ -350,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                 p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))

                 if not is_param_sharded and not self.keep_unshard:
                     # We gather full fp16 param here
colossalai/zero/sharded_param/sharded_param.py:

@@ -26,7 +26,7 @@ class ShardedParamV2(object):
     def get_payload_tensors(self) -> List[StatefulTensor]:
         """returns stateful tensors kept by this class.
         """
-        return [self._sharded_data_tensor, self.saved_grad]
+        return [self._sharded_data_tensor]

     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py (new file, 112 lines):

@@ -0,0 +1,112 @@
+import torch
+import colossalai
+import pytest
+import torch.multiprocessing as mp
+from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
+from colossalai.zero.shard_utils import StatefulTensorMgr
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.zero.sharded_param.tensorful_state import TensorState
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
+from torch.nn.parameter import Parameter
+from typing import List
+from functools import partial
+
+
+class Net(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        # each parameter is 512 MB
+        self.p0 = Parameter(torch.empty(1024, 1024, 128))
+        self.p1 = Parameter(torch.empty(1024, 1024, 128))
+        self.p2 = Parameter(torch.empty(1024, 1024, 128))
+
+
+def run_stm():
+    cuda_capacity = colo_cuda_memory_capacity()
+    fraction = (1.4 * 1024**3) / cuda_capacity
+    # limit max memory to 1.4GB
+    # which means only 2 parameters can be on CUDA
+    colo_set_process_memory_fraction(fraction)
+    model = Net()
+    for p in model.parameters():
+        p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
+    GLOBAL_MODEL_DATA_TRACER.register_model(model)
+    mem_collector = MemStatsCollector()
+    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
+    for p in model.parameters():
+        stateful_tensor_mgr.register_stateful_param(p.colo_attr)
+
+    mem_collector.start_collection()
+    # Compute order: 0 1 2 0 1
+    # warmup
+    # use naive eviction strategy
+    apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    mem_collector.finish_collection()
+    mem_collector.reset_sampling_cnter()
+    stateful_tensor_mgr.reset()
+
+    # warmup done
+    # use OPT-like eviction strategy
+    apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
+    mem_collector.sample_memstats()
+
+
+def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
+                 stateful_tensor_mgr: StatefulTensorMgr):
+    compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
+    for p in model.parameters():
+        if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
+            p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
+    stateful_tensor_mgr.adjust_layout()
+    print_stats(model)
+    device = torch.device(torch.cuda.current_device())
+    cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
+    for n, p in model.named_parameters():
+        if hash(p) in cuda_param_after_adjust:
+            assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
+        else:
+            assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')
+
+
+def print_stats(model: torch.nn.Module):
+    msgs = []
+    for n, p in model.named_parameters():
+        msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
+    print(f'[ {", ".join(msgs)} ]')
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_stm()
+
+
+@pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_stateful_tensor_manager(world_size=1):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_stateful_tensor_manager()
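The capacity numbers in the test are chosen so the residency assertions are forced: each parameter is 1024 * 1024 * 128 fp32 elements (512 MiB), and the process fraction caps CUDA memory at 1.4 GiB, so at most two parameters fit on the device at once. A quick sanity check of that arithmetic, illustrative only:

param_bytes = 1024 * 1024 * 128 * 4                     # fp32 elements -> 512 MiB per parameter
cap_bytes = 1.4 * 1024**3                               # the cap implied by colo_set_process_memory_fraction
assert param_bytes == 512 * 1024**2
assert 2 * param_bytes < cap_bytes < 3 * param_bytes    # two parameters fit, three do not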