mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-18 07:31:19 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
This commit is contained in:
9
colossalai/legacy/zero/gemini/__init__.py
Normal file
9
colossalai/legacy/zero/gemini/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from .ophooks import BaseOpHook, register_ophooks_recursively
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
|
||||
|
||||
__all__ = [
|
||||
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
|
||||
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
|
||||
]
|
48
colossalai/legacy/zero/gemini/gemini_context.py
Normal file
48
colossalai/legacy/zero/gemini/gemini_context.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class GeminiMemoryManager(object):
|
||||
|
||||
def __init__(self, states_cls: EnumMeta):
|
||||
super().__init__()
|
||||
self.states_cls = states_cls
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem = dict()
|
||||
self.state_mem = dict()
|
||||
self.state_mem['cpu'] = dict()
|
||||
self.state_mem['cuda'] = dict()
|
||||
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def total_number(self):
|
||||
return self._cnter
|
||||
|
||||
def reset(self):
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
|
||||
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
|
||||
|
||||
# memory conditions for all states
|
||||
for state in self.states_cls:
|
||||
self.state_mem['cpu'][state] = 0
|
||||
self.state_mem['cuda'][state] = 0
|
||||
|
||||
def register_new_instance(self):
|
||||
self._cnter += 1
|
||||
|
||||
def delete_instance(self):
|
||||
self._cnter -= 1
|
||||
|
||||
def print_info(self):
|
||||
print(f"Total number: {self.total_number}",
|
||||
f"Total CPU memory occupation: {self.total_mem['cpu']}",
|
||||
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
|
||||
sep='\n')
|
||||
|
||||
for state in self.states_cls:
|
||||
print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
|
||||
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
|
||||
sep='\n')
|
3
colossalai/legacy/zero/gemini/ophooks/__init__.py
Normal file
3
colossalai/legacy/zero/gemini/ophooks/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .utils import BaseOpHook, register_ophooks_recursively
|
||||
|
||||
__all__ = ["BaseOpHook", "register_ophooks_recursively"]
|
32
colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py
Normal file
32
colossalai/legacy/zero/gemini/ophooks/_shard_grad_ophook.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
|
||||
from . import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ShardGradMemTracerHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param before and after FWD and BWD operator executing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, '_sharded_grad')
|
||||
param._sharded_grad.setup()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
pass
|
48
colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py
Normal file
48
colossalai/legacy/zero/gemini/ophooks/_shard_param_ophook.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
|
||||
from . import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ShardParamHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param before and after FWD and BWD operator executing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def niter(self):
|
||||
return self._niter
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.gather()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.shard()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.gather()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.shard()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
pass
|
145
colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py
Normal file
145
colossalai/legacy/zero/gemini/ophooks/runtime_mem_tracer_hook.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.zero.gemini.tensor_utils import alloc_storage, free_storage
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
FORWARD = 0
|
||||
BACKWARD = 1
|
||||
|
||||
|
||||
class GradMemStats():
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.unreleased_grad_flag = {}
|
||||
self.unreleased_grad_volume = 0
|
||||
|
||||
def clear(self):
|
||||
self.unreleased_grad_flag.clear()
|
||||
self.unreleased_grad_volume = 0
|
||||
|
||||
|
||||
class GradMemTracerHook():
|
||||
|
||||
def __init__(self, grad_stats: GradMemStats):
|
||||
self.grad_hook_list = []
|
||||
self._grad_stats = grad_stats
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
assert self._grad_stats.unreleased_grad_flag[p]
|
||||
free_storage(grad)
|
||||
self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
|
||||
self._grad_stats.unreleased_grad_flag[p] = False
|
||||
|
||||
def register_grad_hook(self, module: torch.nn.Module):
|
||||
for p in module.parameters():
|
||||
if p.requires_grad:
|
||||
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
|
||||
self._grad_stats.unreleased_grad_flag[p] = False
|
||||
|
||||
def remove_grad_hook(self):
|
||||
for hook in self.grad_hook_list:
|
||||
hook.remove()
|
||||
|
||||
|
||||
class ParamMemTracerHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
|
||||
super().__init__()
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self._memstats = memstats
|
||||
self._grad_stats = gradstats
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
|
||||
def _free_cuda_params(self, params):
|
||||
for p in params:
|
||||
if p.data.device.type == "cpu":
|
||||
raise NotImplementedError("Only free cuda memory")
|
||||
free_storage(p.data)
|
||||
|
||||
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
|
||||
"""
|
||||
move params to cuda
|
||||
|
||||
Args:
|
||||
params (List[torch.nn.Parameter]): target params
|
||||
|
||||
Raises:
|
||||
NotImplementedError: raise error when param has cpu grad
|
||||
"""
|
||||
for p in params:
|
||||
cur_dev = p.data.device.type
|
||||
if cur_dev == "cpu":
|
||||
if p.grad is not None and p.grad.device.type == "cpu":
|
||||
raise NotImplementedError("Only run in forward propagation")
|
||||
p.data = torch.empty(p.data.shape,
|
||||
device="cuda",
|
||||
dtype=p.data.dtype,
|
||||
requires_grad=p.data.requires_grad)
|
||||
elif cur_dev == "cuda":
|
||||
alloc_storage(p.data)
|
||||
|
||||
def record_model_data_volume(self, params):
|
||||
"""
|
||||
get cuda model data used by params
|
||||
"""
|
||||
data_volume = self._grad_stats.unreleased_grad_volume
|
||||
for p in params:
|
||||
cur_model_data_volume = p.data.numel() * p.data.element_size()
|
||||
data_volume += cur_model_data_volume
|
||||
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
|
||||
# add param.grad, actually param.grad is None in this time
|
||||
data_volume += cur_model_data_volume
|
||||
if not self._grad_stats.unreleased_grad_flag[p]:
|
||||
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
|
||||
self._grad_stats.unreleased_grad_flag[p] = True
|
||||
# record max non model data used for this Op
|
||||
self._memstats.record_max_cuda_model_data(data_volume)
|
||||
|
||||
def pre_op(self, params):
|
||||
max_cuda_used_pre_op = self.mem_monitor.finish()
|
||||
# record max cuda overall data for prev OP.
|
||||
self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
|
||||
# record max cuda non model data for prev OP.
|
||||
self._memstats.calc_max_cuda_non_model_data()
|
||||
|
||||
self._allocate_params_on_cuda(params)
|
||||
# record max cuda model data for current OP
|
||||
self.record_model_data_volume(params)
|
||||
|
||||
self.mem_monitor.start()
|
||||
self._memstats.increase_preop_step(params)
|
||||
|
||||
def post_op(self, params):
|
||||
self._free_cuda_params(params)
|
||||
|
||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
||||
self.pre_op(params)
|
||||
|
||||
def post_forward(self, params: List[torch.Tensor]) -> None:
|
||||
self.post_op(params)
|
||||
|
||||
def pre_backward(self, params: List[torch.Tensor]) -> None:
|
||||
self.pre_op(params)
|
||||
|
||||
def post_backward(self, params: List[torch.Tensor]) -> None:
|
||||
self.post_op(params)
|
||||
|
||||
@contextmanager
|
||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||
old_training_phase = self._training_phase
|
||||
try:
|
||||
self._training_phase = training_phase
|
||||
yield
|
||||
finally:
|
||||
self._training_phase = old_training_phase
|
||||
|
||||
switch_to_backward = switch_training_phase
|
||||
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
|
142
colossalai/legacy/zero/gemini/ophooks/utils.py
Normal file
142
colossalai/legacy/zero/gemini/ophooks/utils.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseOpHook(ABC):
|
||||
"""This class allows users to add customized operations
|
||||
before and after the execution of a PyTorch submodule"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_iter(self):
|
||||
pass
|
||||
|
||||
|
||||
# apply torch.autograd.Function that calls a backward_function to tensors in output
|
||||
def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
if type(outputs) is tuple:
|
||||
touched_outputs = []
|
||||
for output in outputs:
|
||||
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
|
||||
touched_outputs.append(touched_output)
|
||||
return tuple(touched_outputs)
|
||||
elif type(outputs) is torch.Tensor:
|
||||
return functional.apply(module, backward_function, outputs)
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, outputs):
|
||||
ctx.module = module
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
module.applied_pre_backward = False
|
||||
outputs = outputs.detach()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
class PostBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, output):
|
||||
ctx.module = module
|
||||
output = output.detach()
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
"""
|
||||
Args:
|
||||
activation_grad of the next layer.
|
||||
Returns:
|
||||
grad of the input activation.
|
||||
"""
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook],
|
||||
name: str = "",
|
||||
filter_fn: Optional[Callable] = None):
|
||||
r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert isinstance(ophook_list, (list, tuple))
|
||||
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
|
||||
|
||||
# Early return on modules with no parameters.
|
||||
if len(list(module.parameters(recurse=False))) == 0:
|
||||
return
|
||||
|
||||
# return from filtered module
|
||||
if filter_fn is not None and filter_fn(module):
|
||||
return
|
||||
|
||||
def _pre_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_fwd_exec(submodule, *args)
|
||||
|
||||
def _post_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_fwd_exec(submodule, *args)
|
||||
|
||||
def _pre_backward_module_hook(submodule, inputs, output):
|
||||
|
||||
def _run_before_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_bwd_exec(submodule, inputs, output)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
|
||||
|
||||
def _post_backward_module_hook(submodule, inputs):
|
||||
|
||||
def _run_after_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_bwd_exec(submodule, inputs)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
|
||||
|
||||
module.register_forward_pre_hook(_pre_forward_module_hook)
|
||||
module.register_forward_hook(_post_forward_module_hook)
|
||||
|
||||
module.register_forward_hook(_pre_backward_module_hook)
|
||||
module.register_forward_pre_hook(_post_backward_module_hook)
|
3
colossalai/legacy/zero/gemini/paramhooks/__init__.py
Normal file
3
colossalai/legacy/zero/gemini/paramhooks/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from ._param_hookmgr import BaseParamHookMgr
|
||||
|
||||
__all__ = ["BaseParamHookMgr"]
|
39
colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py
Normal file
39
colossalai/legacy/zero/gemini/paramhooks/_param_hookmgr.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import functools
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseParamHookMgr(object):
|
||||
|
||||
def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
|
||||
r"""
|
||||
register backward hook on every parameters of module
|
||||
"""
|
||||
self._param_list = param_list
|
||||
self._hook_list = []
|
||||
|
||||
def register_backward_hooks(self, hook_call: Callable) -> None:
|
||||
r"""
|
||||
The hook_call will be called every time a gradient with respect to the a param in self.param_list
|
||||
is computed.
|
||||
The hook should have the following signature:
|
||||
```
|
||||
hook(param, grad) -> Tensor or None
|
||||
```
|
||||
"""
|
||||
if not torch.is_grad_enabled():
|
||||
return # don't register grad hooks if grad isn't enabled
|
||||
for p in self._param_list:
|
||||
if p.requires_grad and not hasattr(p, '_base_param_hook'):
|
||||
handle = p.register_hook(functools.partial(hook_call, p))
|
||||
p._base_param_hook = handle
|
||||
|
||||
def remove_hooks(self) -> None:
|
||||
"""
|
||||
Remove hooks from model parameters.
|
||||
"""
|
||||
|
||||
for p in self._param_list:
|
||||
if p.requires_grad and hasattr(p, '_base_param_hook'):
|
||||
p._base_param_hook.remove()
|
209
colossalai/legacy/zero/gemini/stateful_tensor.py
Normal file
209
colossalai/legacy/zero/gemini/stateful_tensor.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .gemini_context import GeminiMemoryManager
|
||||
|
||||
|
||||
def sizeof_tensor(tensor: torch.Tensor):
|
||||
return tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
HOLD = 1
|
||||
HOLD_AFTER_FWD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
COMPUTE = 4
|
||||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
Inspired from the paper:
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
# Global Stateful Tensor Manager
|
||||
GST_MGR = GeminiMemoryManager(TensorState)
|
||||
|
||||
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = None
|
||||
self._payload_size = 0 # byte size of current payload
|
||||
|
||||
StatefulTensor.GST_MGR.register_new_instance()
|
||||
|
||||
if self._state == TensorState.FREE:
|
||||
# when the state is free, payload should be None
|
||||
assert maybe_tensor is None, f"payload has to None if state is {self._state}"
|
||||
else:
|
||||
# otherwise, payload should not be None
|
||||
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
|
||||
self._payload = maybe_tensor
|
||||
self._payload_size = sizeof_tensor(maybe_tensor)
|
||||
self.__trans_state_update(TensorState.FREE, state)
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return 0 # if a tensor has no storage, 0 should be returned
|
||||
return self._payload.data_ptr()
|
||||
|
||||
def set_null(self) -> None:
|
||||
# notice that free stateful tensor do not need to become null again
|
||||
if self.state != TensorState.FREE:
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
self.__release()
|
||||
|
||||
def is_null(self) -> bool:
|
||||
if self.state == TensorState.FREE:
|
||||
# check sanity here
|
||||
assert self.payload is None
|
||||
return True
|
||||
return False
|
||||
|
||||
def trans_state(self, state: TensorState) -> None:
|
||||
if self.state == TensorState.FREE:
|
||||
# free stateful tensor can't change state
|
||||
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
|
||||
return
|
||||
|
||||
self.__trans_state_update(self.state, state)
|
||||
|
||||
if state == TensorState.FREE:
|
||||
self.__release()
|
||||
else:
|
||||
self._state = state
|
||||
|
||||
def move_to(self, device: Union[torch.device, int]):
|
||||
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
|
||||
|
||||
if not isinstance(device, torch.device):
|
||||
to_device = torch.device('cuda', device)
|
||||
else:
|
||||
to_device = device
|
||||
|
||||
from_device_type = self.device.type
|
||||
if from_device_type == to_device.type:
|
||||
# from device == to device
|
||||
return
|
||||
|
||||
# update manager's information
|
||||
self.__trans_device_update(from_device_type, to_device.type)
|
||||
self.payload.data = self.payload.data.to(to_device)
|
||||
|
||||
def payload_copy(self, tensor) -> None:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def payload_reset(self, tensor) -> None:
|
||||
|
||||
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
|
||||
|
||||
if self.payload is not None:
|
||||
# release old payload
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
else:
|
||||
# otherwise, set the state to HOLD for new payload
|
||||
self._state = TensorState.HOLD
|
||||
del self._payload
|
||||
|
||||
self._payload = tensor
|
||||
self._payload_size = sizeof_tensor(tensor)
|
||||
# record new payload
|
||||
self.__trans_state_update(TensorState.FREE, self.state)
|
||||
|
||||
def payload_relay(self, rhs):
|
||||
# relay the payload of rhs to current stateful tensor
|
||||
# can't support null relay right now
|
||||
assert not rhs.is_null()
|
||||
|
||||
# now this function only support stateful tensor that has zero-length payload
|
||||
# because it doesn't require memory manager updating
|
||||
# you can extend this function by yourself
|
||||
assert self.payload_size == 0
|
||||
|
||||
self._payload = rhs.payload
|
||||
self._payload_size = rhs.payload_size
|
||||
self._state = TensorState.HOLD
|
||||
self.__trans_state_update(rhs.state, TensorState.HOLD)
|
||||
|
||||
rhs.__release()
|
||||
|
||||
@property
|
||||
def payload(self) -> Optional[torch.Tensor]:
|
||||
return self._payload
|
||||
|
||||
@property
|
||||
def payload_size(self) -> int:
|
||||
return self._payload_size
|
||||
|
||||
@property
|
||||
def state(self) -> TensorState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
|
||||
|
||||
def __release(self):
|
||||
# release current payload
|
||||
# shouldn't be visible to users
|
||||
self._state = TensorState.FREE
|
||||
self._payload = None
|
||||
self._payload_size = 0
|
||||
|
||||
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
|
||||
"""Update global manager when changing the state of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
device_type = self.device.type
|
||||
|
||||
if from_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][from_state] -= size
|
||||
else:
|
||||
# when from_state is FREE, the tensor is new to manager
|
||||
# we should add its memory
|
||||
manager.total_mem[device_type] += size
|
||||
|
||||
if to_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][to_state] += size
|
||||
else:
|
||||
# when to_state is FREE, the tensor will be deleted soon
|
||||
# we should sub its memory
|
||||
manager.total_mem[device_type] -= size
|
||||
|
||||
def __trans_device_update(self, from_type: str, to_type: str):
|
||||
"""Update global manager when changing the device of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
state = self.state
|
||||
|
||||
# update aggregated information
|
||||
manager.total_mem[from_type] -= size
|
||||
manager.total_mem[to_type] += size
|
||||
|
||||
# update the information of each state
|
||||
manager.state_mem[from_type][state] -= size
|
||||
manager.state_mem[to_type][state] += size
|
||||
|
||||
def __del__(self):
|
||||
self.set_null()
|
||||
StatefulTensor.GST_MGR.delete_instance()
|
||||
del self
|
103
colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
Normal file
103
colossalai/legacy/zero/gemini/stateful_tensor_mgr.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import functools
|
||||
import types
|
||||
from time import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .stateful_tensor import StatefulTensor, TensorState
|
||||
from .tensor_placement_policy import TensorPlacementPolicy
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class StatefulTensorMgr(object):
|
||||
"""
|
||||
Stateful Tensor Manager, inspired from PatrickStar
|
||||
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
|
||||
self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
|
||||
self._stateful_tensor_list: List[StatefulTensor] = []
|
||||
|
||||
self._compute_list: List[StatefulTensor] = []
|
||||
self._compute_idx: int = -1
|
||||
|
||||
self._cpu_gpu_move_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
self._warmup = True
|
||||
|
||||
def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
|
||||
assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
|
||||
self._stateful_tensor_list = tensor_list
|
||||
for t in self._stateful_tensor_list:
|
||||
assert isinstance(t, StatefulTensor)
|
||||
t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
|
||||
|
||||
def start_iter(self):
|
||||
pass
|
||||
|
||||
def finish_iter(self):
|
||||
"""This function must be called when each iteration finishes
|
||||
"""
|
||||
self._warmup = False
|
||||
self._compute_idx = -1
|
||||
self._cpu_gpu_move_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
|
||||
def adjust_layout(self) -> None:
|
||||
""" Adjust the layout of stateful tensor according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
|
||||
start = time()
|
||||
move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
|
||||
self._layout_time += time() - start
|
||||
vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
|
||||
cuda_demand=cuda_demand,
|
||||
warmup=self._warmup,
|
||||
compute_list=self._compute_list,
|
||||
compute_idx=self._compute_idx)
|
||||
self._cpu_gpu_move_volume += vol
|
||||
self._evict_time += evict_time
|
||||
# move COMPUTE tensors to CUDA
|
||||
self._cpu_gpu_move_volume += cuda_demand
|
||||
for t in move_to_cuda_tensor_list:
|
||||
colo_model_data_tensor_move_inline(t, get_current_device())
|
||||
|
||||
@property
|
||||
def cpu_gpu_move_volume(self):
|
||||
return self._cpu_gpu_move_volume
|
||||
|
||||
def _trans_state(self, trans_state_func, stateful_tensor, state):
|
||||
trans_state_func(state)
|
||||
if state == TensorState.COMPUTE:
|
||||
self._compute_idx += 1
|
||||
if self._warmup:
|
||||
self._compute_list.append(stateful_tensor)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool):
|
||||
move_to_cuda_tensor_list = []
|
||||
hold_cuda_tensor_list = []
|
||||
for tensor in self._stateful_tensor_list:
|
||||
if tensor.state == TensorState.FREE:
|
||||
continue
|
||||
|
||||
if tensor.device.type == 'cuda':
|
||||
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
|
||||
hold_cuda_tensor_list.append(tensor)
|
||||
elif tensor.device.type == 'cpu':
|
||||
if tensor.state == TensorState.COMPUTE:
|
||||
move_to_cuda_tensor_list.append(tensor)
|
||||
else:
|
||||
raise RuntimeError
|
||||
return move_to_cuda_tensor_list, hold_cuda_tensor_list
|
139
colossalai/legacy/zero/gemini/tensor_placement_policy.py
Normal file
139
colossalai/legacy/zero/gemini/tensor_placement_policy.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class TensorPlacementPolicy(ABC):
|
||||
|
||||
def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
self.device: Optional[torch.device] = device
|
||||
self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
|
||||
|
||||
@abstractmethod
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CPUTensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
|
||||
volume = 0
|
||||
for t in hold_cuda_tensor_list:
|
||||
colo_model_data_tensor_move_inline(t, self.device)
|
||||
volume += t.payload.numel() * t.payload.element_size()
|
||||
return volume, 0
|
||||
|
||||
|
||||
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
|
||||
return 0, 0
|
||||
|
||||
|
||||
class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
super().__init__(None, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# TODO(ver217): make these args configurable
|
||||
self._warmup_non_model_data_ratio: float = 0.8
|
||||
self._steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
def evict_tensors(self,
|
||||
hold_cuda_tensor_list: List[StatefulTensor],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: List[StatefulTensor] = [],
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> int:
|
||||
"""
|
||||
Evict tensors from CUDA device.
|
||||
|
||||
Args:
|
||||
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
|
||||
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
|
||||
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
|
||||
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
|
||||
compute_idx (int, optional): the idx of computing device. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
|
||||
Returns:
|
||||
int: the volume of memory that is evicted
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
freed_cuda_model_data = 0
|
||||
end = time()
|
||||
if avail_cuda_model_data < cuda_demand:
|
||||
# Move cuda_demand - avail_cuda_model_data volume of tensors
|
||||
# to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
|
||||
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
|
||||
to_free_tensor_list = hold_cuda_tensor_list
|
||||
if not warmup:
|
||||
to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
|
||||
tuple(compute_list))
|
||||
# print(self._sort_hold_cuda_tensors.cache_info())
|
||||
end = time()
|
||||
for t in to_free_tensor_list:
|
||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||
break
|
||||
freed_cuda_model_data += t.payload_size
|
||||
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
|
||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||
raise RuntimeError(
|
||||
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||
)
|
||||
return freed_cuda_model_data, end - start
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
|
||||
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
|
||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||
if compute_list[i] in next_compute_idx:
|
||||
next_compute_idx[compute_list[i]] = i
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
|
||||
class TensorPlacementPolicyFactory:
|
||||
|
||||
@staticmethod
|
||||
def create(policy_name: str) -> Type[TensorPlacementPolicy]:
|
||||
if policy_name == 'cpu':
|
||||
return CPUTensorPlacementPolicy
|
||||
elif policy_name == 'cuda':
|
||||
return CUDATensorPlacementPolicy
|
||||
elif policy_name == 'auto':
|
||||
return AutoTensorPlacementPolicy
|
||||
else:
|
||||
raise TypeError(f"Unknown tensor placement policy {policy_name}")
|
120
colossalai/legacy/zero/gemini/tensor_utils.py
Normal file
120
colossalai/legacy/zero/gemini/tensor_utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
t = tensor.payload
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
t = tensor
|
||||
else:
|
||||
return 0, 0
|
||||
|
||||
cuda_use, cpu_use = 0, 0
|
||||
|
||||
mem_use = t.storage().size() * t.element_size()
|
||||
if t.device.type == 'cuda':
|
||||
cuda_use += mem_use
|
||||
elif t.device.type == 'cpu':
|
||||
cpu_use += mem_use
|
||||
|
||||
return cuda_use, cpu_use
|
||||
|
||||
|
||||
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
|
||||
torch.Tensor]) -> None:
|
||||
"""
|
||||
A colossal API for model data tensor move.
|
||||
The src and target tensors could be resident on both CPU and GPU.
|
||||
|
||||
NOTE() The source tensor payload will be removed after this function.
|
||||
|
||||
The function will record the communication volume between CPU and GPU.
|
||||
Args:
|
||||
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
|
||||
"""
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t_payload = src_t.payload
|
||||
else:
|
||||
src_t_payload = src_t.data
|
||||
src_dev = src_t_payload.device
|
||||
|
||||
if isinstance(tgt_t, StatefulTensor):
|
||||
tgt_t_payload = tgt_t.payload
|
||||
else:
|
||||
tgt_t_payload = tgt_t.data
|
||||
|
||||
tgt_t_payload.copy_(src_t_payload)
|
||||
|
||||
# remove payload of src_t
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t.set_null()
|
||||
else:
|
||||
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
|
||||
|
||||
|
||||
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
|
||||
int]) -> None:
|
||||
"""
|
||||
move a tensor to the target_device
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
|
||||
target_device: a target device, if type is int, it the index of cuda card.
|
||||
"""
|
||||
if not isinstance(target_device, torch.device):
|
||||
target_device = torch.device(f'cuda:{target_device}')
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.to(target_device)
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(target_device)
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
|
||||
"""colo_model_data_move_to_cpu
|
||||
move a model data tensor from gpu to cpu
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): _description_
|
||||
"""
|
||||
# TODO() optimize the tensor moving with non-blocking
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.cpu()
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(torch.device('cpu'))
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Clone a model data tensor
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
|
||||
target_device (torch.device): the target device
|
||||
Returns:
|
||||
torch.Tensor: a cloned torch tensor
|
||||
"""
|
||||
# TODO() rename this function
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t_payload = t.payload if isinstance(t, StatefulTensor) else t
|
||||
return t_payload
|
Reference in New Issue
Block a user