[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
This commit is contained in:
Hongxin Liu
2023-09-18 16:31:06 +08:00
committed by GitHub
parent 32e7f99416
commit b5f9e37c70
342 changed files with 2919 additions and 4182 deletions

View File

@@ -3,7 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
import torch
from torch import nn
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.legacy.tensor import ProcessGroup
from colossalai.tensor import ColoParameter, ColoTensor
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
# find named_params includes replica

View File

@@ -3,9 +3,8 @@ from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats',
'OrderedParamGenerator'
]

View File

@@ -1,7 +1,6 @@
from typing import Optional
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import ChunkManager
from .memory_stats import MemStats
@@ -33,4 +32,5 @@ class ChunkMemStatsCollector(MemStatsCollector):
@property
def cuda_margin_mem(self) -> float:
    """Device memory capacity minus the maximum overall CUDA usage recorded in the memstats."""
    # The legacy helper is imported at call time rather than at module import time.
    from colossalai.legacy.utils.memory import colo_device_memory_capacity

    capacity = colo_device_memory_capacity(get_current_device())
    return capacity - self._memstats.max_overall_cuda

View File

@@ -5,7 +5,7 @@ from time import sleep, time
import torch
from colossalai.utils import colo_device_memory_used, get_current_device
from colossalai.utils import get_current_device
class MemoryMonitor:
@@ -110,6 +110,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
return max_usage
def _measure_usage(self):
from colossalai.legacy.utils import colo_device_memory_used
max_usage = 0
while self.keep_measuring:
max_usage = max(

View File

@@ -70,7 +70,7 @@ class MemStatsCollector:
Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
from colossalai.zero.legacy.gemini import StatefulTensor
from colossalai.legacy.zero.gemini import StatefulTensor
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']

View File

@@ -1,12 +1,12 @@
import torch.nn
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float
from .memory_stats import MemStats

View File

@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
import torch
from colossalai.legacy.utils.memory import colo_device_memory_capacity
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import Chunk
from .chunk import Chunk, ChunkManager

View File

@@ -1,45 +0,0 @@
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.logging import get_dist_logger
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
                       optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
    """
    Wrap a model and its optimizer into their ZeRO sharded counterparts.

    :param model: Your model object
    :type model: :class:`torch.nn.Module`
    :param optimizer: Your optimizer object
    :type optimizer: :class:`torch.optim.Optimizer`
    :param model_config: Keyword arguments forwarded to ``ShardedModelV2``; ``None`` means no extra arguments
    :type model_config: :class:`dict`
    :param optimizer_config: Keyword arguments forwarded to ``ShardedOptimizerV2``; ``None`` means no extra arguments
    :type optimizer_config: :class:`dict`

    :return: (model, optimizer)
    :rtype: Tuple
    """
    logger = get_dist_logger('convert_to_zero_v2')

    # Both configs are logged exactly as received (possibly None), before defaults are applied.
    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
    logger.info(f'model_config is {model_config}', ranks=[0])

    optimizer_kwargs = {} if optimizer_config is None else optimizer_config
    model_kwargs = {} if model_config is None else model_config

    zero_model = ShardedModelV2(model, **model_kwargs)
    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_kwargs)
    return zero_model, zero_optimizer
__all__ = [
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy'
]

View File

@@ -1,9 +0,0 @@
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
__all__ = [
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
]

View File

@@ -1,48 +0,0 @@
from enum import EnumMeta
class GeminiMemoryManager(object):
    """Counters for the number of tracked instances and their memory per device and per state.

    ``total_mem`` maps ``'cpu'`` / ``'cuda'`` to a size; ``state_mem`` maps the same
    device keys to a dict keyed by the members of ``states_cls``.
    """

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0    # the counter of instances

        self.total_mem = dict()
        self.state_mem = {'cpu': dict(), 'cuda': dict()}

        self.reset()

    @property
    def total_number(self):
        """Current value of the instance counter."""
        return self._cnter

    def reset(self):
        """Zero the instance counter and every memory counter."""
        self._cnter = 0
        for device in ('cpu', 'cuda'):
            # aggregated memory occupation of instances on this device
            self.total_mem[device] = 0
            # memory occupation per state on this device
            per_state = self.state_mem[device]
            for state in self.states_cls:
                per_state[state] = 0

    def register_new_instance(self):
        """Count one more instance."""
        self._cnter += 1

    def delete_instance(self):
        """Count one instance less."""
        self._cnter -= 1

    def print_info(self):
        """Print the totals first, then the CPU/CUDA occupation of every state."""
        summary = (
            f"Total number: {self.total_number}",
            f"Total CPU memory occupation: {self.total_mem['cpu']}",
            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
        )
        print(*summary, sep='\n')

        for state in self.states_cls:
            cpu_line = f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}"
            cuda_line = f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n"
            print(cpu_line, cuda_line, sep='\n')

View File

@@ -1,3 +0,0 @@
from .utils import BaseOpHook, register_ophooks_recursively
__all__ = ["BaseOpHook", "register_ophooks_recursively"]

View File

@@ -1,32 +0,0 @@
import torch
from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardGradMemTracerHook(BaseOpHook):
    """
    An op hook that calls ``setup()`` on each parameter's ``_sharded_grad`` right
    before a module's backward pass. Every other callback does nothing.
    """

    def __init__(self):
        super().__init__()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """No-op."""

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """No-op."""

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Set up the sharded gradient of every parameter of ``module``."""
        for p in module.parameters():
            assert hasattr(p, '_sharded_grad')
            p._sharded_grad.setup()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """No-op."""

    def post_iter(self):
        """No-op."""

View File

@@ -1,48 +0,0 @@
import torch
from colossalai.legacy.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded param before and after FWD and BWD operator executing.

    Parameters are gathered before each forward/backward op and sharded again
    afterwards; in both cases ``param.data`` is re-pointed at the current payload.
    """

    def __init__(self):
        super().__init__()
        # `_niter` used to be read by niter() without ever being assigned, so
        # niter() always raised AttributeError. Start the counter at zero.
        self._niter = 0

    def niter(self):
        """Return the number of times ``post_iter`` has been called."""
        return self._niter

    @staticmethod
    def _gather_params(module: torch.nn.Module):
        # Gather every parameter of the module and expose the gathered payload.
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            param.data = param.ca_attr.payload()

    @staticmethod
    def _shard_params(module: torch.nn.Module):
        # Shard every parameter of the module and expose the sharded payload.
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            param.data = param.ca_attr.payload()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Gather the module's parameters before its forward pass."""
        self._gather_params(module)

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Shard the module's parameters after its forward pass."""
        self._shard_params(module)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Gather the module's parameters before its backward pass."""
        self._gather_params(module)

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Shard the module's parameters after its backward pass."""
        self._shard_params(module)

    def pre_iter(self):
        """No-op."""

    def post_iter(self):
        """Advance the iteration counter reported by ``niter``."""
        self._niter += 1

View File

@@ -1,145 +0,0 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
import torch
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class TrainingPhase(Enum):
    """Which pass of a training step is currently running."""
    FORWARD = 0
    BACKWARD = 1
class GradMemStats():
    """Holds a per-key "gradient not yet released" flag map and a running volume counter."""

    def __init__(self) -> None:
        self.unreleased_grad_flag = dict()
        self.unreleased_grad_volume = 0

    def clear(self):
        """Empty the flag map in place and reset the volume to zero."""
        self.unreleased_grad_volume = 0
        self.unreleased_grad_flag.clear()
class GradMemTracerHook():
    """Registers gradient hooks that free a gradient's storage and update the shared ``GradMemStats``."""

    def __init__(self, grad_stats: GradMemStats):
        self.grad_hook_list = []
        self._grad_stats = grad_stats

    def grad_handle(self, p, grad):
        """Gradient hook for parameter ``p``: free ``grad``'s storage and mark it as released."""
        stats = self._grad_stats
        assert stats.unreleased_grad_flag[p]
        free_storage(grad)
        stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
        stats.unreleased_grad_flag[p] = False

    def register_grad_hook(self, module: torch.nn.Module):
        """Attach ``grad_handle`` to every parameter of ``module`` that requires grad."""
        for p in module.parameters():
            if not p.requires_grad:
                continue
            handle = p.register_hook(partial(self.grad_handle, p))
            self.grad_hook_list.append(handle)
            self._grad_stats.unreleased_grad_flag[p] = False

    def remove_grad_hook(self):
        """Remove every hook registered through ``register_grad_hook``."""
        for handle in self.grad_hook_list:
            handle.remove()
class ParamMemTracerHook(ColoParamOpHook):
    """Parameter-op hook that records CUDA memory statistics around every op.

    Before an op the involved parameters get CUDA storage and the model-data
    volume is recorded; after the op the parameters' storage is freed again.
    """

    def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
        super().__init__()
        self._training_phase = TrainingPhase.FORWARD
        self._memstats = memstats
        self._grad_stats = gradstats
        self.mem_monitor = SyncCudaMemoryMonitor()

    def _free_cuda_params(self, params):
        """Free the storage of each param's data.

        Raises:
            NotImplementedError: if a param's data lives on the CPU.
        """
        for p in params:
            if p.data.device.type == "cpu":
                raise NotImplementedError("Only free cuda memory")
            free_storage(p.data)

    def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
        """
        move params to cuda

        Args:
            params (List[torch.nn.Parameter]): target params

        Raises:
            NotImplementedError: raise error when param has cpu grad
        """
        for p in params:
            cur_dev = p.data.device.type
            if cur_dev == "cpu":
                if p.grad is not None and p.grad.device.type == "cpu":
                    raise NotImplementedError("Only run in forward propagation")
                # Replace the CPU data by an uninitialised CUDA tensor of the same shape/dtype.
                p.data = torch.empty(p.data.shape,
                                     device="cuda",
                                     dtype=p.data.dtype,
                                     requires_grad=p.data.requires_grad)
            elif cur_dev == "cuda":
                alloc_storage(p.data)

    def record_model_data_volume(self, params):
        """
        get cuda model data used by params
        """
        data_volume = self._grad_stats.unreleased_grad_volume
        for p in params:
            cur_model_data_volume = p.data.numel() * p.data.element_size()
            data_volume += cur_model_data_volume
            if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
                # add param.grad, actually param.grad is None in this time
                data_volume += cur_model_data_volume
                if not self._grad_stats.unreleased_grad_flag[p]:
                    self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                    self._grad_stats.unreleased_grad_flag[p] = True
        # record max model data used for this Op
        self._memstats.record_max_cuda_model_data(data_volume)

    def pre_op(self, params):
        """Close the measurement of the previous op, then prepare and start measuring this one."""
        max_cuda_used_pre_op = self.mem_monitor.finish()
        # record max cuda overall data for prev OP.
        self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
        # record max cuda non model data for prev OP.
        self._memstats.calc_max_cuda_non_model_data()

        self._allocate_params_on_cuda(params)
        # record max cuda model data for current OP
        self.record_model_data_volume(params)

        self.mem_monitor.start()
        self._memstats.increase_preop_step(params)

    def post_op(self, params):
        """Free the parameters' CUDA storage once the op is done."""
        self._free_cuda_params(params)

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_forward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.pre_op(params)

    def post_backward(self, params: List[torch.Tensor]) -> None:
        self.post_op(params)

    @contextmanager
    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
        """Context manager that sets the training phase and restores the previous one on exit."""
        old_training_phase = self._training_phase
        try:
            self._training_phase = training_phase
            yield
        finally:
            self._training_phase = old_training_phase

    switch_to_backward = switch_training_phase

    def switch_to_forward(self):
        """Context manager that switches to the FORWARD phase.

        This used to be a ``functools.partial`` stored as a class attribute.
        A partial is not bound to the instance on attribute access (Python <= 3.13),
        so ``hook.switch_to_forward()`` was called without ``self`` and raised
        TypeError. A real method binds correctly on every Python version.
        """
        return self.switch_training_phase(training_phase=TrainingPhase.FORWARD)

View File

@@ -1,142 +0,0 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
import torch
class BaseOpHook(ABC):
    """This class allows users to add customized operations
    before and after the execution of a PyTorch submodule"""

    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback to run before the forward pass of ``module``."""
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Callback to run after the forward pass of ``module``."""
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Callback to run before the backward pass of ``module``."""
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Callback to run after the backward pass of ``module``."""
        pass

    @abstractmethod
    def post_iter(self):
        """Callback for the end of an iteration."""
        pass
# apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if type(outputs) is tuple:
touched_outputs = []
for output in outputs:
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
touched_outputs.append(touched_output)
return tuple(touched_outputs)
elif type(outputs) is torch.Tensor:
return functional.apply(module, backward_function, outputs)
else:
return outputs
class PreBackwardFunction(torch.autograd.Function):
    """Autograd function that returns its tensor input detached in forward and
    calls a stored callback with the module when backward reaches it."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        # Keep the module and the callback on ctx so that backward() can use them.
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        ctx.pre_backward_function(ctx.module)
        # One gradient slot per forward input: None for `module` and for the
        # callback, then the incoming gradients passed through unchanged.
        return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
    """Autograd function that returns its tensor input detached in forward and
    calls a stored callback with the module when backward reaches it."""

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        # NOTE(review): the callback parameter is named `pre_backward_function` here as
        # well, although this class is the "post backward" variant — confirm the naming.
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        # None for the two non-tensor forward inputs, then the gradients unchanged.
        return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook],
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.

    Children are processed first. A module gets hooks only if it directly owns at
    least one parameter and ``filter_fn`` (when given) does not return a truthy value for it.
    """
    assert isinstance(module, torch.nn.Module)
    assert isinstance(ophook_list, (list, tuple))
    assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
    assert all(isinstance(op_hook, BaseOpHook) for op_hook in ophook_list)

    # Descend into the children before dealing with this module.
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Modules without parameters of their own get no hooks.
    if not list(module.parameters(recurse=False)):
        return

    # Neither do modules rejected by the filter.
    if filter_fn is not None and filter_fn(module):
        return

    def _before_forward(submodule, *args):
        for op_hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            op_hook.pre_fwd_exec(submodule, *args)

    def _after_forward(submodule, *args):
        for op_hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            op_hook.post_fwd_exec(submodule, *args)

    def _wrap_output_for_pre_backward(submodule, inputs, output):

        def _on_backward_start(submodule):
            for op_hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                op_hook.pre_bwd_exec(submodule, inputs, output)

        return _apply_to_tensors_only(submodule, PreBackwardFunction, _on_backward_start, output)

    def _wrap_input_for_post_backward(submodule, inputs):

        def _on_backward_end(submodule):
            for op_hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                op_hook.post_bwd_exec(submodule, inputs)

        return _apply_to_tensors_only(submodule, PostBackwardFunction, _on_backward_end, inputs)

    # The registration order is kept exactly as in the original.
    module.register_forward_pre_hook(_before_forward)
    module.register_forward_hook(_after_forward)

    module.register_forward_hook(_wrap_output_for_pre_backward)
    module.register_forward_pre_hook(_wrap_input_for_post_backward)

View File

@@ -1,3 +0,0 @@
from ._param_hookmgr import BaseParamHookMgr
__all__ = ["BaseParamHookMgr"]

View File

@@ -1,39 +0,0 @@
import functools
from typing import Callable, List
import torch
class BaseParamHookMgr(object):
    """Registers and removes gradient hooks on a fixed list of parameters."""

    def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
        r"""
        Store the parameters to manage; no hook is registered until
        ``register_backward_hooks`` is called.
        """
        self._param_list = param_list
        self._hook_list = []

    def register_backward_hooks(self, hook_call: Callable) -> None:
        r"""
        The hook_call will be called every time a gradient with respect to a param in self.param_list
        is computed.
        The hook should have the following signature:
        ```
        hook(param, grad) -> Tensor or None
        ```
        Parameters that do not require grad, or that already carry a hook from
        this manager, are skipped.
        """
        if not torch.is_grad_enabled():
            return    # don't register grad hooks if grad isn't enabled
        for p in self._param_list:
            if p.requires_grad and not hasattr(p, '_base_param_hook'):
                handle = p.register_hook(functools.partial(hook_call, p))
                p._base_param_hook = handle
                self._hook_list.append(handle)

    def remove_hooks(self) -> None:
        """
        Remove hooks from model parameters.

        The ``_base_param_hook`` marker is deleted together with the hook.
        Previously it was left behind, so a later ``register_backward_hooks``
        call silently registered nothing. The hook is also removed when
        ``requires_grad`` was switched off after registration.
        """
        for p in self._param_list:
            handle = getattr(p, '_base_param_hook', None)
            if handle is None:
                continue
            handle.remove()
            del p._base_param_hook
        self._hook_list.clear()

View File

@@ -1,209 +0,0 @@
from enum import Enum
from typing import Optional, Union
import torch
from .gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor):
    """Return the byte size of ``tensor``: element count times bytes per element."""
    bytes_per_element = tensor.element_size()
    return bytes_per_element * tensor.numel()
class TensorState(Enum):
    """States a stateful tensor can be labeled with."""
    FREE = 0    # no payload is held
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4
class StatefulTensor(object):
    """A Structure stores a Torch Tensor and labeled states.
    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

    https://arxiv.org/abs/2108.05818

    Every change of payload, state or device is mirrored into the class-wide
    manager ``GST_MGR`` (total memory per device and memory per device and state).
    """
    # Global Stateful Tensor Manager
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Create a stateful tensor; the payload must be None exactly when ``state`` is FREE."""
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is free, payload should be None
            assert maybe_tensor is None, f"payload has to None if state is {self._state}"
        else:
            # otherwise, payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            # a new payload enters the manager: FREE -> state
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        """Return the payload's data pointer, or 0 when there is no payload."""
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        """Drop the payload and become FREE, updating the manager first."""
        # notice that free stateful tensor do not need to become null again
        if self.state != TensorState.FREE:
            # the manager update must run before the release, it reads the payload's size and device
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        """Return True when the state is FREE."""
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Move to ``state``; a FREE tensor may only "transition" to FREE."""
        if self.state == TensorState.FREE:
            # free stateful tensor can't change state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        """Move the payload to ``device``; an int is interpreted as a CUDA device index."""
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # from device == to device
            return

        # update manager's information
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        """Copy the flattened content of ``tensor`` into the flattened payload."""
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        """Replace the payload by ``tensor``, keeping the manager's counters consistent."""

        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is not None:
            # release old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        """Take over the payload of ``rhs`` (another StatefulTensor) and release ``rhs``."""
        # relay the payload of rhs to current stateful tensor
        # can't support null relay right now
        assert not rhs.is_null()

        # now this function only support stateful tensor that has zero-length payload
        # because it doesn't require memory manager updating
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        # name mangling resolves this to rhs._StatefulTensor__release()
        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    @property
    def payload_size(self) -> int:
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")

    def __release(self):
        # release current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        # reads the payload's device, so a payload must be present when this is called
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should sub its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size

    def __del__(self):
        # give the memory back to the manager, then drop the instance count
        self.set_null()
        StatefulTensor.GST_MGR.delete_instance()
        # NOTE(review): `del self` only unbinds the local name and has no further effect
        del self

View File

@@ -1,103 +0,0 @@
import functools
import types
from time import time
from typing import List
import torch
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired from PatrickStar

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []

        # tensors in the order they entered COMPUTE; only appended to while `_warmup` is True
        self._compute_list: List[StatefulTensor] = []
        # number of COMPUTE transitions seen in the current iteration, minus one
        self._compute_idx: int = -1

        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True

    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        """Take ownership of ``tensor_list`` (allowed once) and intercept each tensor's ``trans_state``."""
        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            assert isinstance(t, StatefulTensor)
            # Replace t.trans_state by a bound wrapper: calling t.trans_state(state) now runs
            # self._trans_state(<original bound trans_state>, t, state).
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def start_iter(self):
        pass

    def finish_iter(self):
        """This function must be called when each iteration finishes
        """
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self) -> None:
        """ Adjust the layout of stateful tensor according to the information provided
        by mem_stats_collector, which should belongs to a Sharded Model.
        """
        # find stateful tensor in state COMPUTE
        # (memory the manager currently accounts to COMPUTE tensors on the CPU)
        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
        start = time()
        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
        self._layout_time += time() - start

        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                                      cuda_demand=cuda_demand,
                                                                      warmup=self._warmup,
                                                                      compute_list=self._compute_list,
                                                                      compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time
        # move COMPUTE tensors to CUDA
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        # Run the tensor's original trans_state, then track COMPUTE transitions.
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)

    # NOTE(review): the cache key is (self, compute_idx, warmup), so a call with a key seen
    # before returns the lists computed the first time, and the cache keeps `self` alive.
    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool):
        """Return (COMPUTE tensors on the CPU, HOLD-like tensors on CUDA) among the registered tensors."""
        move_to_cuda_tensor_list = []
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
            else:
                # only 'cuda' and 'cpu' device types are handled
                raise RuntimeError
        return move_to_cuda_tensor_list, hold_cuda_tensor_list

View File

@@ -1,139 +0,0 @@
import functools
from abc import ABC, abstractmethod
from time import time
from typing import List, Optional, Type
import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class TensorPlacementPolicy(ABC):
    """Abstract base of the policies that decide which held CUDA tensors get evicted."""

    def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        self.device: Optional[torch.device] = device
        self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
        # NOTE(review): annotated `-> None`, but the concrete policies in this file
        # return a 2-tuple — confirm which contract callers rely on.
        raise NotImplementedError
class CPUTensorPlacementPolicy(TensorPlacementPolicy):
    """Policy whose target device is the CPU: every held CUDA tensor is moved there."""

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        cpu = torch.device('cpu')
        super().__init__(cpu, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
        """Move every tensor in the list to ``self.device``.

        Despite the ``int`` annotation, the value returned is the tuple
        ``(moved_bytes, 0)``.
        """
        moved_bytes = 0
        for held in hold_cuda_tensor_list:
            colo_model_data_tensor_move_inline(held, self.device)
            payload = held.payload
            moved_bytes += payload.numel() * payload.element_size()
        return moved_bytes, 0
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
    """Policy whose target device is the current device; it never evicts anything."""

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
        # Nothing is moved. Note that the value returned is a 2-tuple of zeros,
        # not the single int the annotation suggests.
        return 0, 0
class AutoTensorPlacementPolicy(TensorPlacementPolicy):
    """Policy that evicts held CUDA tensors to the CPU only as far as needed to fit the demand."""

    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
        super().__init__(None, mem_stats_collector=mem_stats_collector)
        # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
        # TODO(ver217): make these args configurable
        self._warmup_non_model_data_ratio: float = 0.8
        self._steady_cuda_cap_ratio: float = 0.9

    def evict_tensors(self,
                      hold_cuda_tensor_list: List[StatefulTensor],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: Optional[List[StatefulTensor]] = None,
                      compute_idx: int = 0,
                      **kwargs) -> int:
        """
        Evict tensors from CUDA device.

        Args:
            hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
            cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
            warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
            compute_list (List[StatefulTensor], optional): the order in which tensors are computed; used
                outside warmup to evict the tensors needed furthest in the future first.
                Defaults to None, which is treated as an empty list.
            compute_idx (int, optional): position in ``compute_list`` of the current compute; only later
                positions count as future uses. Defaults to 0.

        Raises:
            RuntimeError: if evicting every candidate still frees less than required.

        Returns:
            a tuple ``(freed_volume, planning_time)`` (despite the ``int`` annotation): the volume
            of memory that is evicted and the time spent before the eviction loop started.
        """
        # A fresh list per call instead of a shared mutable default argument.
        if compute_list is None:
            compute_list = []

        start = time()
        cuda_capacity = colo_device_memory_capacity(get_current_device())
        used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
            cuda_capacity *= self._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        freed_cuda_model_data = 0
        end = time()
        if avail_cuda_model_data < cuda_demand:
            # Move cuda_demand - avail_cuda_model_data volume of tensors
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_tensor_list = hold_cuda_tensor_list
            if not warmup:
                to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
                                                                   tuple(compute_list))
            end = time()
            for t in to_free_tensor_list:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break
                freed_cuda_model_data += t.payload_size
                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(
                    f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                )
        return freed_cuda_model_data, end - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
        """Order the held tensors so that those next used furthest in the future come first."""
        # A tensor that is not used after compute_idx gets index len(compute_list).
        next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
        # Walk backwards so that the earliest use after compute_idx is the value that remains.
        for i in range(len(compute_list) - 1, compute_idx, -1):
            if compute_list[i] in next_compute_idx:
                next_compute_idx[compute_list[i]] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]
class TensorPlacementPolicyFactory:
    """Resolve a placement-policy name to the policy class implementing it."""

    @staticmethod
    def create(policy_name: str) -> Type[TensorPlacementPolicy]:
        """Return the policy class registered under ``policy_name``.

        Args:
            policy_name (str): one of ``'cpu'``, ``'cuda'`` or ``'auto'``.

        Raises:
            TypeError: if ``policy_name`` is none of the known names.

        Returns:
            Type[TensorPlacementPolicy]: the matching policy class (not an instance).
        """
        if policy_name == 'cpu':
            return CPUTensorPlacementPolicy
        if policy_name == 'cuda':
            return CUDATensorPlacementPolicy
        if policy_name == 'auto':
            return AutoTensorPlacementPolicy
        raise TypeError(f"Unknown tensor placement policy {policy_name}")

View File

@@ -1,120 +0,0 @@
from typing import Tuple, Union
import torch
from .stateful_tensor import StatefulTensor
def is_storage_empty(tensor: torch.Tensor) -> bool:
    """Tell whether the storage behind ``tensor`` currently holds zero elements."""
    storage_len = tensor.storage().size()
    return storage_len == 0
def free_storage(tensor: torch.Tensor) -> None:
    """Shrink the storage of ``tensor`` to zero elements; no-op when it is already empty."""
    if is_storage_empty(tensor):
        return
    tensor.storage().resize_(0)
def alloc_storage(tensor: torch.Tensor) -> None:
    """Grow an empty storage back to ``tensor.numel()`` elements; no-op when it is not empty."""
    if not is_storage_empty(tensor):
        return
    backing = tensor.storage()
    backing.resize_(tensor.numel())
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
    """Measure the storage footprint of a tensor, split by device kind.

    Args:
        tensor (Union[torch.Tensor, StatefulTensor]): the tensor to measure; for a
            ``StatefulTensor`` its payload is measured.

    Returns:
        Tuple[int, int]: ``(cuda_use, cpu_use)`` in bytes (storage element count times element
        size). Any other input type, or a device that is neither cuda nor cpu, yields ``(0, 0)``.
    """
    if isinstance(tensor, StatefulTensor):
        target = tensor.payload
    elif isinstance(tensor, torch.Tensor):
        target = tensor
    else:
        return 0, 0

    nbytes = target.storage().size() * target.element_size()
    device_kind = target.device.type
    cuda_use = nbytes if device_kind == 'cuda' else 0
    cpu_use = nbytes if device_kind == 'cpu' else 0
    return cuda_use, cpu_use
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
                                                                                         torch.Tensor]) -> None:
    """Copy the data of ``src_t`` into ``tgt_t`` and then drop the source data.

    The copy is done with ``Tensor.copy_`` on the underlying payload / ``.data`` tensors.
    NOTE() The source no longer holds its data after this call: a ``StatefulTensor`` is set to
    null, a plain tensor gets an empty tensor on its original device and dtype.

    Args:
        src_t (Union[StatefulTensor, torch.Tensor]): source tensor
        tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
    """
    src_is_stateful = isinstance(src_t, StatefulTensor)
    src_payload = src_t.payload if src_is_stateful else src_t.data
    src_device = src_payload.device

    dst_payload = tgt_t.payload if isinstance(tgt_t, StatefulTensor) else tgt_t.data
    dst_payload.copy_(src_payload)

    # release the source data now that it lives in the target
    if src_is_stateful:
        src_t.set_null()
        return
    src_t.data = torch.empty(0, device=src_device, dtype=src_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
                                                                                                    int]) -> None:
    """Move a tensor onto ``target_device`` in place.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
        target_device: the destination; anything that is not a ``torch.device`` is treated as a
            cuda index and turned into ``cuda:<index>``.

    Raises:
        TypeError: if ``t`` is neither a ``torch.Tensor`` nor a ``StatefulTensor``.
    """
    if isinstance(target_device, torch.device):
        destination = target_device
    else:
        destination = torch.device(f'cuda:{target_device}')

    if isinstance(t, torch.Tensor):
        t.data = t.data.to(destination)
        return
    if isinstance(t, StatefulTensor):
        t.move_to(destination)
        return
    raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
    """Move a model data tensor to CPU in place.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): the tensor to move.

    Raises:
        TypeError: if ``t`` is neither a ``torch.Tensor`` nor a ``StatefulTensor``.
    """
    # TODO() optimize the tensor moving with non-blocking
    if isinstance(t, torch.Tensor):
        t.data = t.data.cpu()
        return
    if isinstance(t, StatefulTensor):
        t.move_to(torch.device('cpu'))
        return
    raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
    """Move ``t`` to ``target_device`` and hand back its torch tensor.

    Despite the name, the input itself is moved (via ``colo_model_data_tensor_move_inline``);
    the returned object is the payload of a ``StatefulTensor`` or the input tensor itself.

    Args:
        t (Union[StatefulTensor, torch.Tensor]): a model data tensor
        target_device (torch.device): the target device

    Returns:
        torch.Tensor: the tensor now located on ``target_device``
    """
    # TODO() rename this function
    colo_model_data_tensor_move_inline(t, target_device)
    if isinstance(t, StatefulTensor):
        return t.payload
    return t

View File

@@ -1,3 +0,0 @@
from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']

View File

@@ -1,270 +0,0 @@
import contextlib
import functools
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2
@dataclass
class ZeroContextConfig:
    """Settings that steer how a zero init context treats parameters.

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
        is_replicated (bool, optional): Whether the param is replicated across data parallel group.
            Some parameters are not replicated, e.g. parameters in MOE experts.
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
    """

    target_device: torch.device
    is_replicated: bool = True
    shard_param: bool = False

    def __post_init__(self):
        # Sharding only makes sense for replicated params; replicated params that stay
        # unsharded are required to live on cuda.
        if self.shard_param:
            assert self.is_replicated, "Non-replicated parameters can't be sharded."
        elif self.is_replicated:
            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """A context to initialize model.

    1. Convert the model to fp16 (or bf16 when ``bf16`` is True).
    2. The parameters of the module are adapted to type ShardedParameter.
    3. Shard the param and grad according to flags.

    Args:
        target_device (torch.device): The device where param data are after exiting the context.
        shard_strategy (BaseShardStrategy): Shard strategy instance.
        seed (int, optional): Random seed for weight initialization
        shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model.
            Defaults to None, in which case a fresh ``torch.zeros(1, dtype=torch.long)`` is created for this context.
    """

    def __init__(self,
                 target_device: torch.device,
                 shard_strategy: BaseShardStrategy,
                 seed: int = 2**10 - 1,
                 shard_param: bool = False,
                 default_dtype: Optional[torch.dtype] = None,
                 bf16: bool = False,
                 model_numel_tensor: Optional[torch.Tensor] = None):

        super().__init__(default_dtype=default_dtype)
        # A tensor default in the signature would be created once and shared (and overwritten)
        # by every context that relies on the default, so allocate it per instance instead.
        if model_numel_tensor is None:
            model_numel_tensor = torch.zeros(1, dtype=torch.long)
        self.shard_strategy = shard_strategy
        self.param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.seed = seed
        self.bf16 = bf16
        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

        # register this context so that ZeroContextMgr can temporarily override its config
        ZeroContextMgr().current_context = self

        self.param_numel = {}
        self.top_module = None

    @property
    def target_device(self):
        return self.config.target_device

    @property
    def is_replicated(self):
        return self.config.is_replicated

    @property
    def shard_param(self):
        return self.config.shard_param

    @staticmethod
    def calc_fanin_fanout(tensor: torch.Tensor):
        """We use this function to substitute fan-in and fan-out calculation in torch.nn.init.
        This can help us get correct fan-in and fan-out for sharded tensor.

        Raises:
            ValueError: if the (original) shape has fewer than 2 dimensions.

        Returns:
            Tuple[int, int]: ``(fan_in, fan_out)``.
        """
        assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"

        # get correct shape of input tensor
        if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded:
            tensor_shape = tensor.shape
        else:
            tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape

        dimensions = len(tensor_shape)
        if dimensions < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

        num_input_fmaps = tensor_shape[1]
        num_output_fmaps = tensor_shape[0]
        receptive_field_size = 1
        if dimensions > 2:
            # math.prod is not always available, accumulate the product manually
            # we could use functools.reduce but that is not supported by TorchScript
            for s in tensor_shape[2:]:
                receptive_field_size *= s
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def _pre_context_exec(self):
        """
        The Callback function when entering the context
        """
        self.logger = get_dist_logger("ZeroInitContext")

        # substitute fan-in and fan-out calculation
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

        # patch state-dict loading/saving; the originals are kept to be restored on exit
        self.module_load_from_state_dict = nn.Module._load_from_state_dict
        shard_strategy = self.shard_strategy if self.config.shard_param else None
        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
                                                                  shard_strategy=shard_strategy)
        self.module_state_dict = nn.Module.state_dict
        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
                                                       shard_strategy=shard_strategy,
                                                       state_dict_func=self.module_state_dict,
                                                       process_group=self.dp_process_group)

        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()

        # set new seed for initialization, since we initialize sharded tensor separately
        # we don't want all processes have the same seed
        # otherwise all sharded tensors are same after init
        offset = self.seed + 1    # we want to have more 1 in binary format seed
        torch.manual_seed(self.seed + offset * dist.get_rank())

    def _post_context_exec(self):
        """The callback function when exiting context.
        """
        # broadcast replicated no-shard parameters
        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
        for param in self.param_list:
            assert hasattr(param, 'colo_attr')
            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
            param.colo_attr.set_data_none()

        del self.param_list

        # undo every patch applied in _pre_context_exec
        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        # The saved function is nn.Module._load_from_state_dict, so it has to go back onto that
        # attribute; assigning it to load_state_dict would clobber load_state_dict and leave the
        # patched _load_from_state_dict installed.
        nn.Module._load_from_state_dict = self.module_load_from_state_dict
        nn.Module.state_dict = self.module_state_dict
        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)

        # Only parameters reachable from the last constructed module are counted;
        # top_module stays None when no module was built inside the context.
        if self.top_module is not None:
            params = frozenset(self.top_module.parameters())
        else:
            params = frozenset()
        for param in self.param_numel.keys():
            if param not in params:
                self.param_numel[param] = 0
        self.model_numel_tensor.fill_(sum(self.param_numel.values()))

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.
        """
        self.top_module = module
        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

        def half_fn(t: torch.Tensor):
            return t.to(half_dtype) if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'colo_attr'):
                continue

            self.param_numel[param] = param.numel()

            # convert parameters to half
            param_half = half_fn(param)
            param.data = param_half
            if param.grad is not None:
                grad_half = half_fn(param.grad)
                param.grad.data = grad_half

            # move torch parameters to the target device
            target_device = self.target_device
            param.data = param.data.to(target_device)
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

            param.colo_attr = ShardedParamV2(param, set_data_none=True)

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)

            param.data = param.colo_attr.data_payload    # set param.data to payload

            # mark whether the param is replicated
            param.colo_attr.is_replicated = self.is_replicated

            # mark whether the param should keep not sharded
            # if True, the param is used as Zero stage 2
            param.colo_attr.keep_not_shard = not self.shard_param

            self.param_list.append(param)

        # We must cast buffers
        # If we use BN, buffers may be on CPU and Float
        # We must cast them
        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            buffer.data = cast_fn(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):
    """Singleton that remembers the active ``ZeroInitContext`` and can temporarily swap its config."""

    # the most recently created ZeroInitContext, or None when there is none
    current_context: Optional[ZeroInitContext] = None

    @contextlib.contextmanager
    def hijack_context_config(self, **kwargs):
        """Temporarily replace the config of the current context.

        Args:
            **kwargs: forwarded to ``ZeroContextConfig`` to build the temporary config.

        Yields:
            None. When there is no current context nothing is changed. Otherwise the previous
            config is put back on exit, also when the body of the ``with`` block raises.
        """
        if self.current_context is None:
            yield
        else:
            old_config = self.current_context.config
            self.current_context.config = ZeroContextConfig(**kwargs)
            try:
                yield
            finally:
                # restore even on error so a failing init does not leave the override in place
                self.current_context.config = old_config
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
    """Build a context manager that overrides the current zero context config with
    ``shard_param=False`` on the current cuda device.

    Args:
        is_replicated (bool, optional): value for the ``is_replicated`` field of the temporary config.

    Returns:
        AbstractContextManager: the manager returned by ``ZeroContextMgr.hijack_context_config``.
    """
    manager = ZeroContextMgr()
    current_cuda = torch.device('cuda', torch.cuda.current_device())
    return manager.hijack_context_config(target_device=current_cuda, is_replicated=is_replicated, shard_param=False)
def no_shard_zero_decrator(is_replicated: bool = True):
    """Decorator factory that runs the decorated callable inside ``no_shard_zero_context``.

    Args:
        is_replicated (bool, optional): forwarded to ``no_shard_zero_context``. Defaults to True.

    Returns:
        Callable: a decorator; the wrapped callable keeps the name, docstring and other
        metadata of the original and returns whatever the original returns.
    """

    def _wrapper(init_func):

        # keep __name__/__doc__/__wrapped__ of the decorated function for debugging and introspection
        @functools.wraps(init_func)
        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                ret = init_func(*args, **kwargs)
            return ret

        return _no_shard

    return _wrapper

View File

@@ -1,5 +0,0 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']

View File

@@ -1,22 +0,0 @@
from abc import ABC, abstractmethod
from typing import List, Optional
import torch.distributed as dist
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC):
    """Interface of a shard strategy: how a list of ``ShardedTensor`` is sharded and gathered."""

    def __init__(self) -> None:
        """Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
        """
        super().__init__()

    @abstractmethod
    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Shard every tensor of ``tensor_list`` within ``process_group``; to be implemented by subclasses."""

    @abstractmethod
    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather every tensor of ``tensor_list`` within ``process_group``; to be implemented by subclasses."""

View File

@@ -1,47 +0,0 @@
from typing import List, Optional
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy):
    """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
    which will fully utilize network bandwidth.
    It is especially useful when sub-module contains bias,
    since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small).
    """

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather all sharded tensors of ``tensor_list`` with a single all-gather.

        The local shards are flattened into one buffer, one buffer per rank is all-gathered, and
        every tensor is rebuilt from its slice of each rank's buffer. Tensors that are not
        sharded are left untouched.
        """
        sharded: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
        if not sharded:
            return

        first = sharded[0]
        target_device = first.device
        dtype = first.dtype
        numels = [t.payload.numel() for t in sharded]
        total_numel = sum(numels)
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)

        # one flat buffer per rank: ours carries the local shards, the others are receive slots
        buffers: List[torch.Tensor] = []
        for peer in range(world_size):
            if peer == rank:
                local_flat = flatten([t.payload for t in sharded])
                buffers.append(local_flat.cuda(get_current_device()))
            else:
                buffers.append(torch.zeros(total_numel, dtype=dtype, device=get_current_device()))
        dist.all_gather(buffers, buffers[rank], group=process_group)

        # Move to target device before splitting buffer
        # Ensure we utilize maximum PCIE bandwidth
        buffers = [buf.to(target_device) for buf in buffers]

        start = 0
        for t, numel in zip(sharded, numels):
            stop = start + numel
            pieces = [buf[start:stop] for buf in buffers]
            # drop the trailing padding and restore the original shape
            full_payload = torch.cat(pieces)[:t.origin_numel].view(t.origin_shape)
            t.payload_reset(full_payload)
            t.is_sharded = False
            start = stop

View File

@@ -1,22 +0,0 @@
from typing import Tuple
import torch
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor.

    Args:
        tensor (torch.Tensor): the full tensor; it is flattened before being split.
        rank (int): index of the shard to return.
        world_size (int): number of shards.

    Returns:
        Tuple[torch.Tensor, int]: the shard, zero-padded to the size of the first chunk, and the
        number of padding elements it contains.
    """
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    pieces = list(torch.flatten(tensor).chunk(world_size))
    # torch.chunk may produce fewer pieces than requested; fill up with empty ones
    missing = world_size - len(pieces)
    if missing > 0:
        pieces.extend([pieces[0].new_empty(0) for _ in range(missing)])

    full_piece = pieces[0]
    own_piece = pieces[rank]

    # Determine number of padding elements.
    num_to_pad = full_piece.numel() - own_piece.numel()
    assert num_to_pad >= 0, num_to_pad

    shard = torch.zeros_like(full_piece)
    shard[:own_piece.size(0)].copy_(own_piece)
    return shard, num_to_pad

View File

@@ -1,59 +0,0 @@
from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.shard_utils.commons import get_shard
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
    """
    A naive implementation which shard each tensor evenly over all ranks
    """

    def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Shard every tensor of ``tensor_list`` one by one."""
        for tensor in tensor_list:
            self._shard_tensor(tensor, process_group)

    def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
        """Gather every tensor of ``tensor_list`` one by one."""
        for tensor in tensor_list:
            self._gather_tensor(tensor, process_group)

    def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        """ Shard tensor among processes.

        Args:
            t (ShardedTensor): a tensor to be sharded.
            process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
            Defaults to None.
        """
        if t.is_sharded:
            return
        payload = t.payload
        if payload.device.type == 'cuda':
            assert payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
                f" but current cuda device is {get_current_device()}"
        rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)
        local_shard, _ = get_shard(t.payload, rank, world_size)
        t.payload_reset(local_shard)
        t.is_sharded = True

    def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
        """All-gather the shards of ``t`` and restore its original shape; no-op if ``t`` is not sharded.

        The gathered payload is moved back to the device ``t`` was on before gathering.
        """
        if not t.is_sharded:
            return
        target_device = t.device
        shard_numel = t.payload.numel()
        world_size = dist.get_world_size(process_group)
        rank = dist.get_rank(process_group)

        # one contiguous buffer; each rank's slot is a chunk of it
        buffer = torch.empty(shard_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
        slots = list(torch.chunk(buffer, chunks=world_size, dim=0))
        slots[rank].copy_(t.payload)

        dist.all_gather(slots, slots[rank], group=process_group, async_op=False)
        # cut off the padding at the end before reshaping
        gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
        t.payload_reset(gathered_payload)
        colo_model_data_tensor_move_inline(t, target_device)
        t.is_sharded = False

View File

@@ -1,3 +0,0 @@
from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModelV2']

View File

@@ -1,85 +0,0 @@
from typing import Any, Callable, List, Tuple, Union
import torch
import torch.nn.functional as F
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:
    """Return a power of two obtained by doubling from 1 while it divides ``world_size`` and the
    quotient ``world_size / factor`` is still larger than the factor.

    Args:
        world_size (int): the number of ranks.

    Returns:
        float: the resulting factor.
    """
    factor: int = 1
    while world_size % factor == 0:
        if not world_size / factor > factor:
            break
        factor *= 2
    return float(factor)
def free_storage(data: torch.Tensor) -> None:
    """Free underlying storage of a Tensor; does nothing if the storage is already empty."""
    if not data.storage().size() > 0:
        return
    # Since we're modifying the Tensor's Storage directly, make sure the Tensor
    # is the sole occupant of the Storage.
    assert data.storage_offset() == 0
    data.storage().resize_(0)
@torch.no_grad()
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
    """Allocate storage for a tensor.

    Args:
        data (torch.Tensor): tensor whose storage is resized.
        size (torch.Size): desired size; its ``numel()`` is the target storage length.

    The storage must either already have that length (nothing is done) or be empty.
    """
    current_len = data.storage().size()
    wanted_len = size.numel()
    if current_len == wanted_len:    # no need to reallocate
        return
    assert data.storage().size() == 0
    data.storage().resize_(wanted_len)
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return a float16 version of a float32 tensor; any other tensor is returned unchanged.

    A ``StatefulTensor`` is unwrapped to its payload first.
    """
    payload = tensor.payload if isinstance(tensor, StatefulTensor) else tensor
    is_fp32 = torch.is_floating_point(payload) and payload.dtype is torch.float32
    return payload.half() if is_fp32 else payload
def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
    """Return a float32 version of a float16/bfloat16 tensor; any other tensor is returned unchanged.

    A ``StatefulTensor`` is unwrapped to its payload first.
    """
    payload = tensor.payload if isinstance(tensor, StatefulTensor) else tensor
    is_half = torch.is_floating_point(payload) and payload.dtype in (torch.float16, torch.bfloat16)
    return payload.float() if is_half else payload
def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return a bfloat16 version of a float32 tensor; any other tensor is returned unchanged.

    A ``StatefulTensor`` is unwrapped to its payload first.
    """
    payload = tensor.payload if isinstance(tensor, StatefulTensor) else tensor
    is_fp32 = torch.is_floating_point(payload) and payload.dtype is torch.float32
    return payload.bfloat16() if is_fp32 else payload
def apply_to_tensors(x: Any, fn: Callable):
    """Apply ``fn`` to every tensor found in a nested structure.

    Lists, tuples and dicts (values only) are traversed recursively and rebuilt as list, tuple
    and dict respectively; anything that is not a tensor or one of these containers is returned
    as is.

    Args:
        x (Any): a tensor, a container of tensors, or any other object.
        fn (Callable): function applied to each tensor.
    """
    if torch.is_tensor(x):
        return fn(x)
    if isinstance(x, list):
        return [apply_to_tensors(item, fn) for item in x]
    if isinstance(x, tuple):
        return tuple(apply_to_tensors(item, fn) for item in x)
    if isinstance(x, dict):
        return {name: apply_to_tensors(item, fn) for name, item in x.items()}
    return x
def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """Run ``fn`` over every tensor in the positional and keyword arguments.

    Returns:
        Tuple[Any, Any]: the converted positional arguments (a tuple) and keyword arguments (a dict).
    """
    converted_args = apply_to_tensors(args, fn)
    converted_kwargs = apply_to_tensors(kwargs, fn)
    return converted_args, converted_kwargs
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    """Chunk a given Tensor into num_chunks parts and add any necessary padding.

    The tensor is flattened first. The last chunk is zero-padded up to the size of the first
    one, and all-zero chunks are appended when ``torch.chunk`` yields fewer than ``num_chunks``.
    """
    pieces = list(torch.flatten(tensor).chunk(num_chunks))
    # the final piece may be shorter than the others
    shortfall = pieces[0].numel() - pieces[-1].numel()
    if shortfall > 0:
        pieces[-1] = F.pad(pieces[-1], [0, shortfall])
    # torch.chunk may return fewer than num_chunks chunks, pad accordingly.
    missing = num_chunks - len(pieces)
    if missing > 0:
        pieces.extend([torch.zeros_like(pieces[0]) for _ in range(missing)])
    return pieces

View File

@@ -1,200 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
import os
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
# The base collectives are on unless ENABLE_NCCL_BASE_COLLECTIVES is set to exactly "0".
enable_nccl_base_collectives = os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") != "0"
class Bucket:
    """A ``(group.size(), shard_size)`` staging buffer that collects inputs of several
    reduce-scatter calls and reduces them together on ``flush``.
    """

    def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
        self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
        self.group = group
        # number of columns of the buffer that are currently filled
        self.offset = 0
        self.callbacks: List[Callable] = []
        self.output_shard = torch.zeros_like(self.buffer[0])

    def flush(self) -> None:
        """Flush content of the bucket."""
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return

        # reduce-scatter bucket
        filled = self.offset
        reduced_out = self.output_shard[:filled]
        use_base_collective = hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives
        if use_base_collective:
            dist._reduce_scatter_base(reduced_out, self.buffer[:, :filled].contiguous(), group=self.group)
        else:
            dist.reduce_scatter(reduced_out, list(self.buffer[:, :filled].unbind(0)), group=self.group)

        # execute post-reduction callbacks
        for pending in self.callbacks:
            pending()

        # reuse input bucket but allocate a fresh output shard
        self.buffer[:, :self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.buffer[0])

    def alloc(self) -> None:
        """Re-allocate the storage of the bucket buffers if it has been released.

        Together with ``free`` this allows the buffers to occupy memory only while they are needed.
        """
        for tensor in (self.buffer, self.output_shard):
            storage = tensor.storage()
            if storage.size() == 0:
                storage.resize_(tensor.size().numel())

    def free(self) -> None:
        """Tear down the bucket by freeing the memory"""
        assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
        for tensor in (self.buffer, self.output_shard):
            tensor.storage().resize_(0)

    def append(self, tensor_list: List[Tensor], callback_fn: Callable):
        """Stage one reduce-scatter input in the bucket.

        Args:
            tensor_list (List[Tensor]): one tensor per rank of the group; row ``i`` of the buffer
                receives the flattened ``tensor_list[i]``.
            callback_fn (Callable): if not None, queued to be called on ``flush`` with a view of
                the output shard shaped like ``tensor_list[0]``.
        """
        # copy data from input_list into bucket
        per_rank_numel = tensor_list[0].numel()
        stacked_input = torch.stack(tensor_list).view(self.group.size(), per_rank_numel)
        begin = self.offset
        end = begin + per_rank_numel
        self.buffer[:, begin:end].copy_(stacked_input)
        self.offset += per_rank_numel

        # callback will be given the reduced result
        if callback_fn is None:
            return
        result_view = self.output_shard[begin:end].view_as(tensor_list[0])
        self.callbacks.append(functools.partial(callback_fn, result_view))
class ReduceScatterBucketer:
    """
    Helper for bucketing multiple reduce-scatter operations on small tensors
    into larger reduce-scatter ops to improve communication efficiency.

    Usage::

        bucketer = ReduceScatterBucketer()
        bucketer.reduce_scatter_async(
            small_tensors, group, callback_fn=lambda result: print("small")
        )
        bucketer.reduce_scatter_async(
            big_tensors, group, callback_fn=lambda result: print("big")
        )
        bucketer.reduce_scatter_async(
            more_small_tensors, group, callback_fn=lambda result: print("small2")
        )
        bucketer.flush()  # callbacks only guaranteed to be called after flush()
        # Example output (note that it is out of order, due to bucketing):
        # big
        # small
        # small2

    Args:
        bucket_size_mb (int, Optional): bucket size for communicating. Buckets
            are sub-divided based on world_size. Values <= 0 disable bucketing.
    """

    def __init__(self, bucket_size_mb: int = 25):
        self.bucket_size_mb = bucket_size_mb
        # one bucket per (dtype, device, process group)
        self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}

    @torch.no_grad()
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (len(input_list) == world_size
               ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened, group=group)
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.buffer.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()
        bucket.append(input_list, callback_fn)

    @torch.no_grad()
    def flush(self) -> None:
        """Reduce-scatter any partial buckets."""
        for bucket in self.buckets.values():
            bucket.flush()

    @torch.no_grad()
    def free(self) -> None:
        """Free buffers from all buckets."""
        for bucket in self.buckets.values():
            bucket.free()

    def _get_shard_size(self, element_size: int, num_shards: int) -> int:
        """Number of elements each rank's row of a bucket can hold.

        Deliberately not wrapped in ``functools.lru_cache``: a cache on an instance method is
        keyed on ``self``, which keeps the bucketer (and the buffers of its buckets) alive for
        the lifetime of the process and returns stale sizes once ``bucket_size_mb`` changes.
        The computation is a few arithmetic operations, so caching buys nothing.

        Args:
            element_size (int): size in bytes of one element.
            num_shards (int): number of ranks the bucket is divided among.

        Returns:
            int: the per-shard element count, or 0 when bucketing is disabled.
        """
        if self.bucket_size_mb <= 0:    # Values <= 0 disable bucketing.
            return 0
        MB = 1024 * 1024
        bucket_size = self.bucket_size_mb * MB / element_size
        return int(bucket_size // num_shards)

    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> "Bucket":
        """Return the bucket for ``tensor``'s dtype/device and ``group``, creating and
        allocating it on first use."""
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
        self.buckets[key].alloc()
        return self.buckets[key]

View File

@@ -1,577 +0,0 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import functools
import itertools
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Iterator, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
from ._utils import (
cast_float_arguments,
cast_tensor_to_bf16,
cast_tensor_to_fp16,
cast_tensor_to_fp32,
chunk_and_pad,
free_storage,
get_gradient_predivide_factor,
)
from .zero_hook import ZeroHook
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class ShardedModelV2(nn.Module):
"""
A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward
passes can be executed with limited CUDA memory budget.
Note:
You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.
Note:
Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
if you enable ``reuse_fp16_shard``.
Args:
module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
Defaults to 'cuda'.
gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""
def __init__(self,
             module: nn.Module,
             shard_strategy: BaseShardStrategy,
             process_group: Optional[ProcessGroup] = None,
             reduce_scatter_process_group: Optional[ProcessGroup] = None,
             reduce_scatter_bucket_size_mb: int = 25,
             fp32_reduce_scatter: bool = False,
             tensor_placement_policy: str = 'cuda',
             gradient_predivide_factor: Optional[float] = 1.0,
             reuse_fp16_shard: bool = False,
             bf16: bool = False,
             *args,
             **kwargs):
    """Wrap ``module`` and set up sharding state, memory tracing, hooks and gradient reduction.

    In addition to the named arguments, ``kwargs`` may contain
    ``warmup_non_model_data_ratio``; it is only applied when
    ``tensor_placement_policy`` is ``'auto'`` (a warning is logged otherwise).
    """
    assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
    super().__init__()
    self.logger = get_dist_logger()
    self.bf16 = bf16

    # ZeroInitContext is mandatory: every parameter must carry `colo_attr`, and the
    # parameters directly owned by a submodule must be all sharded or all unsharded.
    for submodule in module.modules():
        sharded_cnt = unshard_cnt = 0
        for param in submodule.parameters(recurse=False):
            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
            if param.colo_attr.param_is_sharded:
                sharded_cnt += 1
            else:
                unshard_cnt += 1
        assert sharded_cnt == 0 or unshard_cnt == 0, 'nn.Module can not both have shard param and unshard param'
        submodule.param_is_sharded = sharded_cnt > 0

    # Split all parameters by their sharding status at construction time.
    self.sharded_params = []
    self.unshard_params = []
    for param in module.parameters():
        bucket = self.sharded_params if param.colo_attr.param_is_sharded else self.unshard_params
        bucket.append(param)

    self.module = module
    self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
    self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
    self.world_size = dist.get_world_size(self.process_group)
    self.rank = dist.get_rank(self.process_group)
    self.shard_strategy = shard_strategy

    # The memory tracer is only created for the 'auto' placement policy.
    self._use_memory_tracer = tensor_placement_policy == 'auto'
    if not self._use_memory_tracer:
        self._memstats_collector = None
    else:
        self._memstats_collector = MemStatsCollector()
        # NOTE(review): `disposable` presumably turns these into one-shot calls -- confirm.
        self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
        self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)

    policy_cls = TensorPlacementPolicyFactory.create(tensor_placement_policy)
    self._tensor_placement_policy: TensorPlacementPolicy = policy_cls(mem_stats_collector=self._memstats_collector)

    if 'warmup_non_model_data_ratio' in kwargs:
        if tensor_placement_policy == 'auto':
            ratio = kwargs['warmup_non_model_data_ratio']
            self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
            self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
        else:
            self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')

    # Hand every parameter's stateful tensor to the manager driven by the placement policy.
    self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
    param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
    self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)

    # Register the operator hooks and the per-parameter backward hooks.
    zero_hook = ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr,
                         self.process_group)
    self._ophook_list = [zero_hook]
    register_ophooks_recursively(self.module, self._ophook_list)
    self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
    self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)

    self.fp32_reduce_scatter = fp32_reduce_scatter
    # Any policy other than 'cuda' makes gradients get offloaded.
    self._cpu_offload: bool = tensor_placement_policy != 'cuda'
    for param in module.parameters():
        # Init `offload_grad`
        param.colo_attr.offload_grad = self._cpu_offload

    # A predivide factor different from 1.0 was observed to cause precision problems,
    # hence 1.0 is the default. Passing None selects a value >= 1.0 automatically.
    if gradient_predivide_factor is None:
        gradient_predivide_factor = get_gradient_predivide_factor(self.world_size)
    self.gradient_predivide_factor: float = gradient_predivide_factor
    self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor

    self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
    self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
    self._require_backward_grad_sync: bool = True

    self._cuda_margin_space = 0
    self.reuse_fp16_shard = reuse_fp16_shard

    # Incremented whenever a saved gradient contains inf or nan.
    self.overflow_counter = 0
def adjust_stateful_tensor_layout(self) -> None:
    """Ask the stateful tensor manager to adjust the layout of the tensors it tracks."""
    manager = self._stateful_tensor_mgr
    manager.adjust_layout()
@property
def use_memory_tracer(self):
    """Whether the memory tracer is enabled (set when the placement policy is 'auto')."""
    return self._use_memory_tracer
@property
def cuda_margin_space(self):
    """The most recently computed CUDA margin space (0 until memory stats are updated)."""
    return self._cuda_margin_space
@property
def cpu_offload(self):
    """Whether CPU offload is active (true for every placement policy except 'cuda')."""
    return self._cpu_offload
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
    """
    Dump the information collected by the memory tracer to a file.

    Nothing happens unless the memory tracer is enabled, and only global
    rank 0 writes the file. Typical use is post-mortem debugging::

        try:
            # forward: model(inputs)
            # backward: optimizer.backward()
        except Exception as e:
            model.dump_memory_stats()
            exit(0)

    Args:
        filename (Optional[str], optional): Path of the file to (over)write.
            Defaults to 'dump_mem_stats.log'.
    """
    if not self._use_memory_tracer:
        return

    # Logged at error level on purpose: the message is kept exactly as before.
    self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
    if gpc.get_global_rank() != 0:
        return

    with open(filename, 'w+') as f:
        f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
        f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
        f.write('CUDA model data (GB)\n')
        f.write('\n')
        f.write('CUDA non model data (GB)\n')
        f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
        # Terminate the CUDA list line; previously the next header was glued onto it.
        f.write('\n')
        f.write('CPU non model data (GB)\n')
        f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
        f.write('\n')
def _pre_forward_operations(self, *args):
    """Prepare for a forward pass: start memory-stat collection, reset tensor states, start an iteration."""
    # Starting the collection here affects what the memory tracer records inside ZeroHook.
    if self._memstats_collector:
        self._start_collect_memstats()

    # Every parameter managed by ZeRO starts the iteration in HOLD state.
    colo_params = (p for p in self.module.parameters() if hasattr(p, 'colo_attr'))
    for param in colo_params:
        param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

    self._stateful_tensor_mgr.start_iter()
def _post_forward_operations(self):
    """Put every ZeRO-managed parameter back into HOLD state after the forward pass."""
    colo_params = (p for p in self.module.parameters() if hasattr(p, 'colo_attr'))
    for param in colo_params:
        param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Run the wrapped module after casting the arguments with the bf16 or fp16 cast function."""
    self._pre_forward_operations(*args)

    # The cast function follows the precision chosen at construction time.
    if self.bf16:
        cast_fn = cast_tensor_to_bf16
    else:
        cast_fn = cast_tensor_to_fp16
    args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)

    outputs = self.module(*args, **kwargs)
    self._post_forward_operations()
    return outputs
def backward(self, loss):
    """Back-propagate ``loss``, run the post-backward operations, then notify every op hook."""
    loss.backward()
    self._post_backward_operations()
    for hook in self._ophook_list:
        hook.post_iter()
def backward_by_grad(self, tensor, grad):
    """Back-propagate from ``tensor`` using the given ``grad``, then run the post-backward steps and op hooks."""
    torch.autograd.backward(tensors=tensor, grad_tensors=grad)
    self._post_backward_operations()
    for hook in self._ophook_list:
        hook.post_iter()
def _update_memstats(self):
    """Finish memory-stat collection and refresh the CUDA margin space (no-op without a collector)."""
    if not self._memstats_collector:
        return

    self._finish_collect_memstats()
    # cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
    # This relies on the assumption that model data stays fixed in CUDA during
    # training; the margin can then be used to store optimizer states (OS).
    capacity = colo_device_memory_capacity(get_current_device())
    self._cuda_margin_space = capacity - self._memstats_collector._memstats.max_overall_cuda
@torch.no_grad()
def _post_backward_operations(self) -> None:
    """
    Run the operations that have to happen after a backward pass.

    1. Update the memory tracer.
    2. On a synchronizing pass, flush the gradients still held in the reduce-scatter
       buckets; the reducer is freed in every case.
    3. Re-shard the tensors that were not handled in the zero hook.
    4. On a synchronizing pass, set ``grad`` of every trainable parameter to ``None``.
    """
    # 1. update memory tracer.
    self._update_memstats()

    # 2. flush the gradient in buckets. Reducing partial gradients in each process.
    if self._require_backward_grad_sync:
        # Flush any unreduced buckets in the post_backward stream.
        with torch.cuda.stream(self.comm_stream):
            self.reducer.flush()
        torch.cuda.current_stream().wait_stream(self.comm_stream)
    self.reducer.free()

    # 3. shard tensors not handled in the zero hook
    pending = []
    for param in self.sharded_params:
        attr = param.colo_attr
        if attr.param_is_sharded:
            continue
        pending.append(attr.sharded_data_tensor)
        attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
        attr.set_data_none()
    self.shard_strategy.shard(pending, self.process_group)

    # 4. set all parameters' grad to None
    # NOTE: on a no-sync pass (_require_backward_grad_sync is False) nothing is reset,
    # so p.grad keeps whatever was accumulated. No-sync and sync passes may therefore
    # be interleaved if desired.
    if self._require_backward_grad_sync:
        for param in self.module.parameters():
            if param.requires_grad:
                param.grad = None
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
    """
    Backward hook that takes over the freshly computed gradient of ``param``.

    When the hook starts, ``grad`` holds the full gradient for the local batch.
    For replicated parameters it is handed to the reduce-scatter path, which
    saves a single shard of the gradient summed across all GPUs to
    ``param.colo_attr``; that shard aligns with the current GPU rank. For example::

        before reduce_scatter:
            param.grad (GPU #0): [1, 2, 3, 4]
            param.grad (GPU #1): [5, 6, 7, 8]

        after reduce_scatter:
            param.grad (GPU #0): [6, 8]    # 1+5, 2+6
            param.grad (GPU #1): [10, 12]  # 3+7, 4+8

    The local GPU's ``optim.step`` is responsible for updating a single shard of
    params, also corresponding to the current GPU's rank, so the local optimizer
    only sees the relevant parameter shard.

    Returns:
        ``None`` when ``grad`` is ``None`` or on a no-sync pass; otherwise a
        storage-freed placeholder tensor shaped like ``grad``.
    """
    if grad is None:
        return
    assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
    if not self._require_backward_grad_sync:
        return

    # PyTorch does not let the hook return None here, so hand back a placeholder
    # whose storage has been released.
    placeholder = torch.empty_like(grad)
    free_storage(placeholder)

    # torch does not allow modifying grad inside a hook, so work on a copy.
    grad_copy = grad.clone()
    handler = self._reduce_scatter_handler if param.colo_attr.is_replicated else self._save_grad
    handler(param, grad_copy)
    return placeholder
def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
    """Pre-process ``grad`` on the communication stream and reduce-scatter it (or save it directly)."""
    # The communication stream must see everything queued on the current stream so far.
    self.comm_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(self.comm_stream):
        if self.fp32_reduce_scatter:
            grad.data = grad.data.to(param.dtype)

        if self.gradient_predivide_factor > 1.0:
            # Average grad by world_size for consistency with PyTorch DDP.
            grad.data.div_(self.gradient_predivide_factor)

        if self.world_size <= 1:
            # Single process: nothing to reduce, go straight to the callback.
            self._reduce_scatter_callback(param, grad)
        else:
            group = self.reduce_scatter_process_group
            grad_chunks = chunk_and_pad(grad, group.size())
            on_reduced = functools.partial(self._reduce_scatter_callback, param)
            self.reducer.reduce_scatter_async(grad_chunks, group=group, callback_fn=on_reduced)

    torch.cuda.current_stream().wait_stream(self.comm_stream)
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
    """Flatten the reduced gradient, apply the post-divide factor and save it for ``param``."""
    assert isinstance(reduced_grad,
                      torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"

    flat = reduced_grad.data.contiguous().view(-1)
    reduced_grad.data = flat

    postdivide = self.gradient_postdivide_factor
    if postdivide > 1:
        # Average grad by world_size for consistency with PyTorch DDP.
        reduced_grad.data.div_(postdivide)

    self._save_grad(param, reduced_grad)
# FIXME(ver217): refactor the below line when impl eviction policy
def _save_grad(self, param: Parameter, grad: torch.Tensor):
    """Record overflow, optionally offload ``grad`` to CPU, and store it on ``param.colo_attr``."""
    attr = param.colo_attr

    # record whether we have overflow
    has_inf = torch.isinf(grad).any().item()
    self.overflow_counter += has_inf
    has_nan = torch.isnan(grad).any().item()
    self.overflow_counter += has_nan

    # move gradient to cpu
    if attr.offload_grad:
        colo_model_data_move_to_cpu(grad)

    if not self.reuse_fp16_shard:
        # Keep the saved gradient in fp32, accumulating into an existing payload.
        fp32_grad = cast_tensor_to_fp32(grad)
        if attr.saved_grad.is_null():
            attr.grad_payload_reset(fp32_grad)
        else:
            attr.grad_payload.add_(fp32_grad.view_as(attr.grad_payload))
    else:
        # make parameters point to gradient
        assert attr.saved_grad.is_null(
        ), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
        attr.grad_payload_reset(grad.data)
        # Release the memory of the param: its payload becomes an empty tensor that
        # still carries device and dtype, so the optimizer can query them later.
        attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
        if attr.is_replicated:
            attr.sharded_data_tensor.is_sharded = True

    # keep saved_grad in HOLD state
    attr.saved_grad.trans_state(TensorState.HOLD)
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    """Return the parameter iterator of the wrapped module."""
    wrapped = self.module
    return wrapped.parameters(recurse=recurse)
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
    """Return the ``(name, parameter)`` iterator of the wrapped module."""
    wrapped = self.module
    return wrapped.named_parameters(prefix, recurse)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
    """Build the state dict of the wrapped module via ``_colo_state_dict`` using this model's sharding setup."""
    sharding_options = dict(shard_strategy=self.shard_strategy,
                            state_dict_func=nn.Module.state_dict,
                            module_to_load=self.module,
                            sharded_params=self.sharded_params,
                            process_group=self.process_group)
    return self._colo_state_dict(destination, prefix, keep_vars, **sharding_options)
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
    """Load parameter payloads from ``state_dict`` and re-shard each loaded parameter.

    Raises:
        RuntimeError: if ``strict`` is true and a parameter name is missing from ``state_dict``.
    """
    for name, param in self.named_parameters():
        if name not in state_dict:
            if strict:
                raise RuntimeError(f'Missing key in state_dict: {name}')
            continue

        attr = param.colo_attr
        current = attr.data_payload
        # The incoming tensor takes over dtype and device of the existing payload.
        attr.data_payload_reset(state_dict[name].to(dtype=current.dtype, device=current.device))

        # Force re-shard
        attr.sharded_data_tensor.is_sharded = False
        self.shard_strategy.shard([attr.sharded_data_tensor])
def _colo_state_dict(self,
                     destination=None,
                     prefix='',
                     keep_vars=False,
                     shard_strategy: 'Optional[BaseShardStrategy]' = None,
                     state_dict_func=None,
                     module_to_load=None,
                     sharded_params=None,
                     process_group=None) -> 'OrderedDict[str, torch.Tensor]':
    """Gather the sharded parameters, build a CPU state dict, then shard the parameters again.

    Args:
        destination: Forwarded to ``state_dict_func``.
        prefix: Forwarded to ``state_dict_func``.
        keep_vars: Forwarded to ``state_dict_func``.
        shard_strategy: Strategy used to gather before and shard after building
            the state dict. When ``None`` no gather/shard is performed.
        state_dict_func: Callable invoked as
            ``state_dict_func(module_to_load, destination, prefix, keep_vars)``.
        module_to_load: Module whose state dict is built. Defaults to ``self``.
        sharded_params: Parameters to gather. When ``None`` or empty, the parameters
            of ``self`` whose ``colo_attr.param_is_sharded`` is true are used. The
            argument itself is never modified.
        process_group: Process group forwarded to the shard strategy.

    Returns:
        A dict with every tensor value moved to CPU; non-tensor values are kept as-is.
    """
    # Build a fresh list instead of appending to the argument: the previous mutable
    # default (`[]`) was shared between calls, so parameters collected by the first
    # call were reused by every later call, even on other model instances.
    if not sharded_params:
        sharded_params = [param for param in self.parameters() if param.colo_attr.param_is_sharded]

    if shard_strategy is not None:
        shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
    # Expose the payloads through `.data` so state_dict_func can see them.
    for p in sharded_params:
        p.data = p.colo_attr.data_payload

    module_to_load = module_to_load or self
    gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
    gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}

    if shard_strategy is not None:
        shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
    for p in sharded_params:
        p.colo_attr.set_data_none()
    return gathered_state_dict
def _colo_load_from_state_dict(self,
                               state_dict,
                               prefix,
                               local_metadata,
                               strict,
                               missing_keys,
                               unexpected_keys,
                               error_msgs,
                               shard_strategy=None):
    r"""Copy parameters and buffers from :attr:`state_dict` into only this
    module, but not its descendants.

    Parameters carrying ``colo_attr`` get their data payload replaced (and are
    re-sharded when ``shard_strategy`` is given); all other entries are copied
    in place after a shape check.

    .. note::
        :attr:`state_dict` is not the same object as the input
        :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
        it can be modified.

    Args:
        state_dict (dict): a dict containing parameters and
            persistent buffers.
        prefix (str): the prefix for parameters and buffers used in this
            module
        local_metadata (dict): a dict containing the metadata for this module;
            it is only passed on to the load-state-dict pre-hooks.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` with :attr:`prefix` match the names of
            parameters and buffers in this module
        missing_keys (list of str): if ``strict=True``, add missing keys to
            this list
        unexpected_keys (list of str): if ``strict=True``, add unexpected
            keys to this list
        error_msgs (list of str): error messages should be added to this
            list, and will be reported together in
            :meth:`~torch.nn.Module.load_state_dict`
        shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None.
    """
    for pre_hook in self._load_state_dict_pre_hooks.values():
        pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

    # Local state = parameters + persistent buffers, skipping entries set to None.
    persistent_buffers = ((k, v) for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set)
    local_state = {
        k: v for k, v in itertools.chain(self._parameters.items(), persistent_buffers) if v is not None
    }

    for name, param in local_state.items():
        key = prefix + name
        if key not in state_dict:
            if strict:
                missing_keys.append(key)
            continue

        input_param = state_dict[key]

        if hasattr(param, 'colo_attr'):
            attr = param.colo_attr
            current = attr.data_payload
            attr.data_payload_reset(input_param.to(dtype=current.dtype, device=current.device))
            if shard_strategy is not None:
                # Force re-shard
                attr.sharded_data_tensor.is_sharded = False
                shard_strategy.shard([attr.sharded_data_tensor])
            continue

        # This is used to avoid copying uninitialized parameters into
        # non-lazy modules, since they dont have the hook to do the checks
        # in such case, it will error when accessing the .shape attribute.
        is_param_lazy = torch.nn.parameter.is_lazy(param)
        if not is_param_lazy:
            # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
            if len(param.shape) == 0 and len(input_param.shape) == 1:
                input_param = input_param[0]

            if input_param.shape != param.shape:
                # local shape should match the one in checkpoint
                error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                  'the shape in current model is {}.'.format(
                                      key, input_param.shape, param.shape))
                continue

        try:
            with torch.no_grad():
                param.copy_(input_param)
        except Exception as ex:
            error_msgs.append('While copying the parameter named "{}", '
                              'whose dimensions in the model are {} and '
                              'whose dimensions in the checkpoint are {}, '
                              'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
                                                                   ex.args))

    extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
    overrides_extra_state = (getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state)
                             is not nn.Module.set_extra_state)
    if overrides_extra_state:
        if extra_state_key in state_dict:
            self.set_extra_state(state_dict[extra_state_key])
        elif strict:
            missing_keys.append(extra_state_key)
    elif strict and (extra_state_key in state_dict):
        unexpected_keys.append(extra_state_key)

    if not strict:
        return
    for key in state_dict.keys():
        if not key.startswith(prefix) or key == extra_state_key:
            continue
        # get the name of param/buffer/child
        input_name = key[len(prefix):].split('.', 1)[0]
        if input_name not in self._modules and input_name not in local_state:
            unexpected_keys.append(key)
def __getitem__(self, idx: int):
    """Index into the wrapped module, which must be an ``nn.ModuleList``."""
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return wrapped[idx]
def __len__(self):
    """Number of entries of the wrapped module, which must be an ``nn.ModuleList``."""
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return len(wrapped)
def __iter__(self):
    """Iterate over the wrapped module, which must be an ``nn.ModuleList``."""
    wrapped = self.module
    assert isinstance(wrapped, nn.ModuleList)
    return iter(wrapped)

View File

@@ -1,20 +0,0 @@
import copy
import torch
from colossalai.zero.legacy.sharded_model import ShardedModelV2
def col_model_deepcopy(sharded_model: 'ShardedModelV2', other_model: torch.nn.Module):
    """
    Copy the parameter payloads of ``sharded_model`` into ``other_model``.

    Parameters are paired positionally, so ``other_model`` has to have the same
    structure as the sharded model. A parameter that is currently sharded is
    gathered for the copy and sharded again afterwards.
    """
    strategy = sharded_model.shard_strategy
    for zero_param, target_param in zip(sharded_model.parameters(), other_model.parameters()):
        assert hasattr(zero_param, 'colo_attr')
        stateful_tensor = zero_param.colo_attr.sharded_data_tensor
        # Remember the status before gathering so it can be restored below.
        was_sharded = stateful_tensor.is_sharded

        if was_sharded:
            strategy.gather([stateful_tensor])
        target_param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
        if was_sharded:
            strategy.shard([stateful_tensor])

View File

@@ -1,118 +0,0 @@
from typing import Optional
import torch
import torch.distributed as dist
from colossalai.legacy.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.

    Around every forward/backward execution of a module it gathers the module's
    own parameters, exposes their payloads, and afterwards shards them again and
    drops the exposed data.

    Warning: this class has been deprecated after version 0.1.12
    """

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstarts_collector: Optional[MemStatsCollector] = None,
                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                 process_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.logger = get_dist_logger("ZeROHook")
        self.shard_strategy = shard_strategy
        self.process_group = process_group

        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        self.computing_device = get_current_device()

        self._memstarts_collector = memstarts_collector
        self._stateful_tensor_mgr = stateful_tensor_mgr

    def _own_stateful_tensors(self, module: torch.nn.Module):
        """Collect the stateful tensors of the parameters owned directly by ``module``."""
        tensors = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'colo_attr')
            tensors.append(param.colo_attr.sharded_data_tensor)
        return tensors

    def _expose_payload(self, module: torch.nn.Module, not_on_cuda_msg: str):
        """Adjust the layout, gather the parameters and point ``param.data`` at the payload."""
        self.adjust_module_data(module)
        self.gather_parameters(module)
        for param in module.parameters(recurse=False):
            param.data = param.colo_attr.data_payload
            assert param.data.device.type == 'cuda', not_on_cuda_msg

    def _release_payload(self, module: torch.nn.Module, state):
        """Move the parameters to ``state``, shard them again and drop ``param.data``."""
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(state)
        self.shard_parameters(module)
        # remove torch payload
        for param in module.parameters(recurse=False):
            param.colo_attr.set_data_none()

    def gather_parameters(self, module: torch.nn.Module):
        # gather sharded parameters
        if not module.param_is_sharded:
            return
        tensor_list = self._own_stateful_tensors(module)
        self.shard_strategy.gather(tensor_list, self.process_group)

    def shard_parameters(self, module: torch.nn.Module):
        # shard gathered parameters
        if not module.param_is_sharded:
            return
        tensor_list = self._own_stateful_tensors(module)
        self.shard_strategy.shard(tensor_list, self.process_group)

    def adjust_module_data(self, module: torch.nn.Module):
        collector = self._memstarts_collector

        # record overall data statistics
        if collector:
            collector.sample_overall_data()

        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

        # adjust stateful tensor to get enough CUDA memory
        self._stateful_tensor_mgr.adjust_layout()

        # record model data statistics
        if collector:
            collector.record_model_data_volume()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self._expose_payload(module, "PRE FWD param.data must be on CUDA")

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        # change tensor state to HOLD_AFTER_FWD
        self._release_payload(module, TensorState.HOLD_AFTER_FWD)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        self._expose_payload(module, "PRE BWD param.data must be on CUDA")

    def post_bwd_exec(self, module: torch.nn.Module, input):
        # change tensor state to HOLD_AFTER_BWD
        self._release_payload(module, TensorState.HOLD_AFTER_BWD)

    def pre_iter(self):
        pass

    def post_iter(self):
        # Without a stateful tensor manager there is nothing to report or to finish.
        if not self._stateful_tensor_mgr:
            return
        self.logger.debug(
            f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
            ranks=[0])
        self._stateful_tensor_mgr.finish_iter()

View File

@@ -1,3 +0,0 @@
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizerV2']

View File

@@ -1,399 +0,0 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
from enum import Enum
from os import stat
from typing import Dict, Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
class OptimState(Enum):
    """Marks whether the optimizer currently regards the gradients as loss-scaled."""

    # Set by the optimizer's backward passes when loss scaling is in use.
    SCALED = 1
    # Initial state of the optimizer.
    UNSCALED = 2
class ShardedOptimizerV2(ColossalaiOptimizer):
"""A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO).
By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.
We apply the Device-aware Operator Placement technique for OS placement from the following paper.
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
which is detected by a runtime memory tracer.
We place as many OS chunks in the margin space as possible.
The size of margin space can be controlled by ``gpu_margin_mem_ratio``.
If it is set as ``0.0``, it is the same as classical ZeRO optimizer.
Note:
You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``.
Note:
Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`,
if you set ``gpu_margin_mem_ratio > 0``.
Args:
sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
shard strategy provided by sharded model to shard param fp32 tensors.
optimizer (Optimizer): An Optimizer instance.
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
which will be used when using hybrid CPU optimizer.
This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
Defaults to 0.0.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""
def __init__(self,
             sharded_model: ShardedModelV2,
             optimizer: Optimizer,
             gpu_margin_mem_ratio: float = 0.0,
             initial_scale: float = 2**32,
             min_scale: float = 1,
             growth_factor: float = 2,
             backoff_factor: float = 0.5,
             growth_interval: int = 1000,
             hysteresis: int = 2,
             max_scale: float = 2**32,
             dp_process_group: Optional[ProcessGroup] = None,
             mp_process_group: Optional[ProcessGroup] = None,
             verbose: bool = False) -> None:
    """Wrap ``optimizer`` for use with ``sharded_model`` and set up the dynamic grad scaler."""
    assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
    assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'

    super().__init__(optimizer)
    self.shard_strategy = sharded_model.shard_strategy
    self.model: ShardedModelV2 = sharded_model
    self.bf16 = sharded_model.bf16

    self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
    assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, 'gpu_margin_mem_ratio must >=0.0 and <=1.0'

    # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
    # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
    # and it must set `num_fp32_shards_per_param` correctly
    self._should_move_fp32_shards_h2d: bool = (sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0
                                               and getattr(optimizer, 'num_fp32_shards_per_param', 0) >= 2)

    placement_policy = sharded_model._tensor_placement_policy
    self.device = placement_policy.device or torch.device('cpu')
    self.optim_state: OptimState = OptimState.UNSCALED
    self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
    self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)

    # Grad scaler
    scaler_options = dict(initial_scale=initial_scale,
                          min_scale=min_scale,
                          growth_factor=growth_factor,
                          backoff_factor=backoff_factor,
                          growth_interval=growth_interval,
                          hysteresis=hysteresis,
                          max_scale=max_scale)
    self.grad_scaler = DynamicGradScaler(**scaler_options)
    self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
    self._logger = get_dist_logger("ShardedOptimizerV2")
    self._verbose = verbose
    # Should be set to true when _prepare_grads() runs and reset to false on backward.
    self._grad_prepared: bool = False

    # Store fp32 param shards
    self._register_master_weight()

    ratio_requested = self.gpu_margin_mem_ratio != 0.0
    if ratio_requested and not isinstance(placement_policy, AutoTensorPlacementPolicy):
        self._logger.warning('gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
                             ranks=[0])

    if self._verbose:
        self._logger.debug(
            f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])

    self._use_memory_tracer = self.model.use_memory_tracer
@property
def loss_scale(self):
    """Return the grad scaler's current scale as a Python number (``scale.item()``)."""
    scale_tensor = self.grad_scaler.scale
    return scale_tensor.item()
def get_memory_usage(self) -> Tuple[int, int]:
    """Sum the memory held by the optimizer.

    Every entry of ``self.master_params`` and every value stored in the inner
    optimizer's per-parameter state is passed to ``colo_tensor_mem_usage`` and
    the results are accumulated.

    Returns:
        Tuple[int, int]: cuda/cpu memory usage in Byte.
    """
    totals = [0, 0]    # [cuda bytes, cpu bytes]

    def _account(tensor):
        on_cuda, on_cpu = colo_tensor_mem_usage(tensor)
        totals[0] += on_cuda
        totals[1] += on_cpu

    for master_param in self.master_params.values():
        _account(master_param)

    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            for state_value in self.optim.state[param].values():
                _account(state_value)

    return totals[0], totals[1]
def zero_grad(self, *args, **kwargs):
    """Clear gradients via ``_zero_grad``.

    Positional and keyword arguments are accepted but ignored.
    """
    self._zero_grad(recover_data=False)
def backward(self, loss: Tensor) -> None:
    """Run the model's backward pass on ``loss``.

    Unless bf16 is enabled, the loss is first multiplied by ``loss_scale`` and
    the optimizer state is marked as ``OptimState.SCALED``. In all cases the
    "grads prepared" flag is cleared before delegating to ``self.model.backward``.
    """
    scaling_enabled = not self.bf16
    if scaling_enabled:
        loss_to_backward = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
    else:
        loss_to_backward = loss
    self._grad_prepared = False
    self.model.backward(loss_to_backward)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
    """Run the model's backward pass from an externally supplied gradient.

    Per the original author's notes: this is called on every pipeline stage
    except the last one, and ``grad`` arrives already scaled from the previous
    rank, so it is not scaled again here. It still has to be unscaled before
    optimizing, which is why the state is marked ``OptimState.SCALED`` when
    bf16 is disabled.
    """
    if not self.bf16:
        self.optim_state = OptimState.SCALED
    # Gradients produced by this backward have not been prepared yet.
    self._grad_prepared = False
    model = self.model
    model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
    """Prepare and (if needed) unscale gradients, then delegate clipping to the parent class.

    Returns:
        Whatever the parent ``clip_grad_norm(model, max_norm)`` returns.
    """
    self._prepare_grads()
    if not self.bf16:
        # Only scaled gradients need unscaling before their norm is measured.
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
    return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
    """Perform one optimizer step on the fp32 master parameters.

    Sequence:
      1. prepare gradients and unscale them if they are still scaled;
      2. optionally move fp32 shards to the GPU (``_maybe_move_fp32_shards``);
      3. unless bf16 is enabled, check for overflow and update the grad scaler;
         on overflow a warning is logged, gradients are zeroed with
         ``recover_data=True`` and the step is skipped;
      4. point each parameter's ``data`` at its master copy, run the inner
         optimizer, then copy the master values back to the half-precision payloads.

    Args:
        *args, **kwargs: forwarded unchanged to the inner optimizer's ``step``.

    Returns:
        The inner optimizer's ``step`` result, or ``None`` when the step was
        skipped because of overflow.
    """
    self._prepare_grads()
    # unscale grads if scaled
    if not self.bf16 and self.optim_state == OptimState.SCALED:
        self._unscale_grads()

    self._maybe_move_fp32_shards()

    if not self.bf16:
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.warning('found inf during ShardedOptimV2 step')
            self._zero_grad(recover_data=True)
            return None

    self._point_param_fp16_to_master_param()

    if self._verbose:
        gpu_mem, cpu_mem = self.get_memory_usage()
        # The second figure is the CPU usage; it was previously mislabelled as CUDA.
        self._logger.debug(
            f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CPU Memory!",
            ranks=[0])

    ret = self.optim.step(*args, **kwargs)

    if self._verbose:
        gpu_mem, cpu_mem = self.get_memory_usage()
        self._logger.debug(
            f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CPU Memory!",
            ranks=[0])

    self._copy_master_model_to_model_fp16()
    return ret
def _check_overflow(self):
    """Return True if an overflow was recorded on any rank.

    The reusable flag tensor is overwritten with the model's
    ``overflow_counter`` (which also discards the previous record), then
    all-reduced first over the data parallel group and then over the model
    parallel group.
    """
    overflow_flag = self._found_overflow
    overflow_flag.fill_(self.model.overflow_counter)

    # data parallel group first, model parallel group second
    for process_group in (self.dp_process_group, self.mp_process_group):
        dist.all_reduce(overflow_flag, group=process_group)

    return overflow_flag.item() > 0
def _unscale_grads(self):
    """Divide every existing gradient in place by the current loss scale.

    Must only be called while ``optim_state`` is ``OptimState.SCALED`` (this is
    asserted); afterwards the state is set to ``OptimState.UNSCALED``.
    Parameters whose ``grad`` is ``None`` are skipped.
    """
    assert self.optim_state == OptimState.SCALED
    # Read the scale once: `loss_scale` performs a `.item()` call on every
    # access and its value does not change while gradients are being divided.
    loss_scale = self.loss_scale
    for group in self.optim.param_groups:
        for p in group['params']:
            if p.grad is not None:
                p.grad.data.div_(loss_scale)
    self.optim_state = OptimState.UNSCALED
def _zero_grad(self, recover_data: bool = False):
"""zero grad and maybe recover fp16 params
When `reuse_fp16_shard` is enabled,
p.colo_attr.sharded_data_tensor stores grad here.
We have to recover them from fp32 params.
Args:
recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
"""
# We must set grad to None
# Because grad here is sharded
# But next backward pass will create a full grad first
# Which leads to wrong accumulation
self.optim.zero_grad(set_to_none=True)
for group in self.optim.param_groups:
for p in group['params']:
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
if recover_data and reuse_fp16_shard:
self._copy_master_param_to_param_fp16(p)
else:
# release saved gradient
p.colo_attr.saved_grad.set_null()
self.model.overflow_counter = 0 # set overflow counter to zero
def sync_grad(self):
    """Do nothing; this optimizer performs no work in ``sync_grad``."""
    return None
def _register_master_weight(self):
    """Build ``self.master_params``: one fp32 ``StatefulTensor`` per optimized parameter.

    Each master copy is created from the parameter's data payload, moved to
    ``self.device`` and cast with ``cast_tensor_to_fp32``. A replicated
    parameter whose data is not sharded yet is sharded just for the copy and
    gathered again afterwards.
    """
    self.master_params: Dict[Parameter, StatefulTensor] = {}
    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            assert hasattr(param, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
            attr = param.colo_attr
            temporarily_shard = not attr.sharded_data_tensor.is_sharded and attr.is_replicated
            if temporarily_shard:
                # we always shard replicated parameters
                self.shard_strategy.shard([attr.sharded_data_tensor], self.dp_process_group)

            fp32_payload = cast_tensor_to_fp32(attr.data_payload.to(self.device))
            self.master_params[param] = StatefulTensor(fp32_payload)

            if temporarily_shard:
                # The param itself does not need to stay sharded in this
                # branch, so it is gathered back here.
                self.shard_strategy.gather([attr.sharded_data_tensor], self.dp_process_group)
def _maybe_move_fp32_shards(self):
    """Move fp32 master shards (with their grads and optimizer state) to the current CUDA device, once.

    Runs only while ``_should_move_fp32_shards_h2d`` is set and clears that
    flag immediately. The budget is ``cuda_margin_space * gpu_margin_mem_ratio``
    divided by the inner optimizer's ``num_fp32_shards_per_param``; a shard is
    moved only if the running total plus its size stays strictly below it.
    Parameters without a saved grad are skipped.
    """
    if not self._should_move_fp32_shards_h2d:
        return
    self._should_move_fp32_shards_h2d = False

    margin_bytes = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
    budget_bytes = margin_bytes / self.optim.num_fp32_shards_per_param
    used_bytes = 0

    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            if param.colo_attr.saved_grad.is_null():
                continue
            master_payload = self.master_params[param].payload
            shard_bytes = master_payload.numel() * self.master_params[param].payload.element_size()
            if not (used_bytes + shard_bytes < budget_bytes):
                continue
            colo_model_data_tensor_move_inline(self.master_params[param], torch.cuda.current_device())
            colo_model_data_tensor_move_inline(param.colo_attr.saved_grad, torch.cuda.current_device())
            param.colo_attr.offload_grad = False
            used_bytes += shard_bytes
            # Keep the optimizer state tensors on the same side as the moved shard.
            param_state = self.optim.state[param]
            for key, value in param_state.items():
                if isinstance(value, Tensor):
                    param_state[key] = value.cuda()
def _prepare_grads(self):
    """Expose each saved gradient as ``p.grad`` so the inner optimizer can use it.

    No-op when gradients are already prepared. Parameters without a saved
    grad are skipped. Sets ``_grad_prepared`` to True at the end.
    """
    if self._grad_prepared:
        return

    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            attr = param.colo_attr
            if attr.saved_grad.is_null():
                continue
            attr.saved_grad.trans_state(TensorState.COMPUTE)
            # With reuse_fp16_shard, an fp16 grad that was not offloaded may
            # have been evicted to CPU, so bring it back to the current device.
            if not attr.offload_grad:
                colo_model_data_tensor_move_inline(attr.saved_grad, torch.cuda.current_device())
            # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information.
            # Changing p.grad directly may raise an error because p.data and
            # p.grad would differ in shape/dtype/device, so p.data is pointed
            # at the saved grad payload first.
            grad_payload = attr.grad_payload
            param.data = grad_payload
            param.grad = grad_payload
            # Reset p.data to an empty tensor to avoid leaking memory.
            attr.set_data_none()

    self._grad_prepared = True
def _point_param_fp16_to_master_param(self):
    """Point every parameter's ``data`` at its master copy's payload.

    Only references are reassigned; no data copy is triggered here. Each
    master tensor is switched to ``TensorState.COMPUTE`` first.
    """
    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            master = self.master_params[param]
            master.trans_state(TensorState.COMPUTE)
            # p.data is sharded from here on, so the optimizer states that the
            # inner optimizer builds from it are sharded naturally.
            param.data = master.payload
def _copy_master_model_to_model_fp16(self):
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
self._copy_master_param_to_param_fp16(p)
def _copy_master_param_to_param_fp16(self, p):
    """Write the master copy of ``p`` back into its half-precision data payload.

    The target dtype is bfloat16 when ``self.bf16`` is set, float16 otherwise.
    The saved gradient is flushed on the way, and the master tensor ends in
    ``TensorState.HOLD``.
    """
    attr = p.colo_attr
    shard = attr.sharded_data_tensor

    # flush gradient
    if shard.payload_size == 0:
        # reuse_fp16_shard is True here: hand the saved grad's storage to the
        # sharded data tensor so that the copy below has a payload to write to.
        shard.payload_relay(attr.saved_grad)
    else:
        attr.saved_grad.set_null()

    p.data = self.master_params[p].payload

    # keep_not_shard parameters need freshly allocated memory, otherwise the
    # tensor sizes are not compatible with the copy below.
    if attr.data_payload.numel() != p.data.numel():
        fresh_payload = torch.empty(p.data.shape, dtype=attr.data_payload.dtype, device=attr.data_payload.device)
        attr.data_payload_reset(fresh_payload)

    # TODO() optimize this line CPU (fp32) -> GPU (fp16)
    if self.bf16:
        half_dtype = torch.bfloat16
    else:
        half_dtype = torch.float16
    shard.payload_copy(p.to(half_dtype).detach())
    attr.set_data_none()

    if attr.keep_not_shard and attr.is_replicated:
        # Gather the full half-precision param here. Only the gradient is
        # sharded, so the flag has to be set to True before gathering.
        shard.is_sharded = True
        self.shard_strategy.gather([shard], self.dp_process_group)

    self.master_params[p].trans_state(TensorState.HOLD)
def state_dict(self):
    """Return the parent's state dict with the grad scaler's state added under ``'scaler'``."""
    combined = super().state_dict()
    combined['scaler'] = self.grad_scaler.state_dict()
    return combined
def load_state_dict(self, *args, **kwargs):
    """Load optimizer state and, when present, the grad scaler state.

    The first positional argument is the state dict. Its ``'scaler'`` entry is
    popped (the passed dict is modified) and loaded into the grad scaler; a
    warning is logged if the entry is missing. The remaining arguments are
    forwarded to the parent implementation. Afterwards every tensor in the
    inner optimizer state is converted to the dtype and device of the
    corresponding master parameter.
    """
    incoming = args[0]
    if 'scaler' in incoming:
        self.grad_scaler.load_state_dict(incoming.pop('scaler'))
    else:
        self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])

    super().load_state_dict(*args, **kwargs)

    for param_group in self.optim.param_groups:
        for param in param_group['params']:
            param_state = self.optim.state[param]
            for key, value in param_state.items():
                if not isinstance(value, Tensor):
                    continue
                master = self.master_params[param]
                param_state[key] = value.to(dtype=master.dtype, device=self.master_params[param].device)

View File

@@ -1,4 +0,0 @@
from .sharded_param import ShardedParamV2
from .sharded_tensor import ShardedTensor
__all__ = ['ShardedTensor', 'ShardedParamV2']

View File

@@ -1,110 +0,0 @@
from typing import List, Optional, Tuple
import torch
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
from .sharded_tensor import ShardedTensor
# Cache of zero-element tensors, keyed by (device, dtype).
EMPTY_TENSOR_DICT = {}


def get_empty_tensor(device: torch.device, dtype: torch.dtype):
    """Return the cached zero-element tensor for ``(device, dtype)``, creating it on first use."""
    cache_key = (device, dtype)
    try:
        return EMPTY_TENSOR_DICT[cache_key]
    except KeyError:
        placeholder = torch.empty(0, dtype=dtype, device=device)
        EMPTY_TENSOR_DICT[cache_key] = placeholder
        return placeholder
class ShardedParamV2:
    """Per-parameter record holding the data as a ``ShardedTensor`` and the saved grad as a ``StatefulTensor``."""

    def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
        self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
        # This attribute must be initialized in ShardedModel
        self.offload_grad: bool = False

        # The sharded param should be the only owner of the payload, but
        # param.data may still be used to initialise other parts of the model,
        # e.g. File "resnet.py", line 190, in __init__
        #   nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # so .data is only emptied here when the caller asks for it.
        self.param = param
        if set_data_none:
            self.set_data_none()

    def get_payload_tensors(self) -> List[StatefulTensor]:
        """Return the stateful tensors kept by this object (only the sharded data tensor)."""
        return [self._sharded_data_tensor]

    def set_data_none(self):
        """Point ``param.data`` at the shared empty tensor matching the data tensor's device and dtype."""
        shard = self.sharded_data_tensor
        self.param.data = get_empty_tensor(shard.device, shard.dtype)

    def set_grad_none(self):
        """Drop the saved gradient."""
        self.saved_grad.set_null()

    @property
    def sharded_data_tensor(self):
        """The ``ShardedTensor`` wrapping the parameter data."""
        return self._sharded_data_tensor

    @property
    def data_payload(self):
        """Payload of the sharded data tensor; asserts that it is not null."""
        shard = self.sharded_data_tensor
        assert not shard.is_null()
        return shard.payload

    @property
    def grad_payload(self):
        """Payload of the saved gradient; asserts that it is not null."""
        grad = self.saved_grad
        assert not grad.is_null()
        return grad.payload

    @property
    def param_is_sharded(self):
        """Value of the data tensor's ``is_sharded`` flag."""
        return self.sharded_data_tensor.is_sharded

    def data_payload_reset(self, tensor: torch.Tensor):
        """Replace the data payload with ``tensor`` (exactly a ``torch.Tensor``, not requiring grad)."""
        assert type(tensor) is torch.Tensor
        assert tensor.requires_grad is False
        self.sharded_data_tensor.payload_reset(tensor)

    def grad_payload_reset(self, tensor: torch.Tensor):
        """Replace the saved grad payload with ``tensor`` (exactly a ``torch.Tensor``, not requiring grad)."""
        assert type(tensor) is torch.Tensor
        assert tensor.requires_grad is False
        self.saved_grad.payload_reset(tensor)

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the param, including data and grad.

        Tensors are de-duplicated by ``data_ptr()`` so that storage reachable
        through more than one attribute is counted once.

        Returns:
            Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte
        """
        usage = [0, 0]    # [cuda bytes, cpu bytes]

        def _tally(t: Optional[torch.Tensor]):
            if t is None:
                return
            assert isinstance(t, torch.Tensor)
            t_cuda, t_cpu = colo_tensor_mem_usage(t)
            usage[0] += t_cuda
            usage[1] += t_cpu

        seen_ptrs = set()

        # data payload: always counted
        _tally(self.data_payload)
        seen_ptrs.add(self.data_payload.data_ptr())

        # saved grad
        if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in seen_ptrs:
            _tally(self.grad_payload)
            seen_ptrs.add(self.saved_grad.data_ptr())

        # whatever param.data currently points at
        if self.param.data is not None and self.param.data.data_ptr() not in seen_ptrs:
            _tally(self.param.data)
            seen_ptrs.add(self.param.data.data_ptr())

        # whatever param.grad currently points at
        if self.param.grad is not None and self.param.grad.data_ptr() not in seen_ptrs:
            _tally(self.param.grad)

        return usage[0], usage[1]

View File

@@ -1,40 +0,0 @@
import torch
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):
    """A ``StatefulTensor`` that can be sharded across multiple processes.

    Records the shape, element count and dtype of the tensor it was built
    from, and carries an ``is_sharded`` flag that starts as ``False`` and is
    updated through the property setter.
    """

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        r"""
        A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.

        Args:
            tensor (torch.Tensor): the initial payload; must not require grad.
            state (TensorState, optional): initial state. Defaults to ``TensorState.HOLD``.
        """
        assert tensor.requires_grad is False
        super().__init__(tensor, state)

        # keep the shape, numel and dtype of the init tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        self._is_sharded = False

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the current payload; asserts it still equals the original dtype."""
        assert self._payload.dtype == self._origin_dtype
        return self._payload.dtype

    @property
    def origin_numel(self) -> int:
        """Number of elements of the tensor this object was constructed from."""
        return self._origin_numel

    @property
    def origin_shape(self) -> torch.Size:
        """Shape of the tensor this object was constructed from (``tensor.shape``, not an int)."""
        return self._origin_shape

    @property
    def is_sharded(self) -> bool:
        """Whether the tensor is currently marked as sharded."""
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool):
        self._is_sharded = flag

View File

@@ -7,9 +7,6 @@ from torch import Tensor, inf
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ProcessGroup
from colossalai.tensor import ColoParameter
from colossalai.utils import is_model_parallel_parameter
def flatten(input_):
    """Flatten ``input_`` with ``torch._utils._flatten_dense_tensors`` and return the result."""
    flat = _flatten_dense_tensors(input_)
    return flat