mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-31 16:40:41 +00:00
[legacy] clean up legacy code (#4743)
* [legacy] remove outdated codes of pipeline (#4692) * [legacy] remove cli of benchmark and update optim (#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (#4696) * [legacy] clean up utils (#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci
This commit is contained in:
@@ -3,7 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
|
||||
from colossalai.legacy.tensor import ProcessGroup
|
||||
from colossalai.tensor import ColoParameter, ColoTensor
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
|
||||
# find named_params includes replica
|
||||
|
@@ -3,9 +3,8 @@ from .memory_stats import MemStats # isort:skip
|
||||
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
||||
from .memstats_collector import MemStatsCollector # isort:skip
|
||||
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
||||
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||
|
||||
__all__ = [
|
||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||
'StaticMemStatsCollector', 'MemStats', 'OrderedParamGenerator'
|
||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', 'MemStats',
|
||||
'OrderedParamGenerator'
|
||||
]
|
||||
|
@@ -1,7 +1,6 @@
|
||||
from typing import Optional
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import ChunkManager
|
||||
|
||||
from .memory_stats import MemStats
|
||||
@@ -33,4 +32,5 @@ class ChunkMemStatsCollector(MemStatsCollector):
|
||||
|
||||
@property
|
||||
def cuda_margin_mem(self) -> float:
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
return colo_device_memory_capacity(get_current_device()) - self._memstats.max_overall_cuda
|
||||
|
@@ -5,7 +5,7 @@ from time import sleep, time
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import colo_device_memory_used, get_current_device
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
@@ -110,6 +110,7 @@ class AsyncMemoryMonitor(MemoryMonitor):
|
||||
return max_usage
|
||||
|
||||
def _measure_usage(self):
|
||||
from colossalai.legacy.utils import colo_device_memory_used
|
||||
max_usage = 0
|
||||
while self.keep_measuring:
|
||||
max_usage = max(
|
||||
|
@@ -70,7 +70,7 @@ class MemStatsCollector:
|
||||
Sampling model data statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
from colossalai.zero.legacy.gemini import StatefulTensor
|
||||
from colossalai.legacy.zero.gemini import StatefulTensor
|
||||
|
||||
# The following code work for ZeroInitContext, which is deprecated in v0.1.12
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
|
@@ -1,12 +1,12 @@
|
||||
import torch.nn
|
||||
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float
|
||||
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
|
||||
from colossalai.legacy.zero.gemini.ophooks.runtime_mem_tracer_hook import (
|
||||
GradMemStats,
|
||||
GradMemTracerHook,
|
||||
ParamMemTracerHook,
|
||||
)
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float
|
||||
|
||||
from .memory_stats import MemStats
|
||||
|
||||
|
@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.chunk import Chunk
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
|
@@ -1,45 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
|
||||
from .shard_utils import BucketTensorShardStrategy, TensorShardStrategy
|
||||
from .sharded_model import ShardedModelV2
|
||||
from .sharded_optim import ShardedOptimizerV2
|
||||
|
||||
|
||||
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
|
||||
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
|
||||
"""
|
||||
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
|
||||
|
||||
:param model: Your model object
|
||||
:type model: :class:`torch.nn.Module`
|
||||
:param optimizer_config: Your optimizer object
|
||||
:type optimizer_config: :class:`dict`
|
||||
|
||||
:return: (model, optimizer)
|
||||
:rtype: Tuple
|
||||
"""
|
||||
|
||||
logger = get_dist_logger('convert_to_zero_v2')
|
||||
|
||||
logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
|
||||
if optimizer_config is None:
|
||||
optimizer_config = dict()
|
||||
logger.info(f'model_config is {model_config}', ranks=[0])
|
||||
if model_config is None:
|
||||
model_config = dict()
|
||||
|
||||
zero_model = ShardedModelV2(model, **model_config)
|
||||
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
|
||||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
__all__ = [
|
||||
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
|
||||
'no_shard_zero_decrator', 'TensorShardStrategy', 'BucketTensorShardStrategy'
|
||||
]
|
@@ -1,9 +0,0 @@
|
||||
from .ophooks import BaseOpHook, register_ophooks_recursively
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
|
||||
|
||||
__all__ = [
|
||||
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
|
||||
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
|
||||
]
|
@@ -1,48 +0,0 @@
|
||||
from enum import EnumMeta
|
||||
|
||||
|
||||
class GeminiMemoryManager(object):
|
||||
|
||||
def __init__(self, states_cls: EnumMeta):
|
||||
super().__init__()
|
||||
self.states_cls = states_cls
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem = dict()
|
||||
self.state_mem = dict()
|
||||
self.state_mem['cpu'] = dict()
|
||||
self.state_mem['cuda'] = dict()
|
||||
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def total_number(self):
|
||||
return self._cnter
|
||||
|
||||
def reset(self):
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
|
||||
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
|
||||
|
||||
# memory conditions for all states
|
||||
for state in self.states_cls:
|
||||
self.state_mem['cpu'][state] = 0
|
||||
self.state_mem['cuda'][state] = 0
|
||||
|
||||
def register_new_instance(self):
|
||||
self._cnter += 1
|
||||
|
||||
def delete_instance(self):
|
||||
self._cnter -= 1
|
||||
|
||||
def print_info(self):
|
||||
print(f"Total number: {self.total_number}",
|
||||
f"Total CPU memory occupation: {self.total_mem['cpu']}",
|
||||
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
|
||||
sep='\n')
|
||||
|
||||
for state in self.states_cls:
|
||||
print(f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
|
||||
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
|
||||
sep='\n')
|
@@ -1,3 +0,0 @@
|
||||
from .utils import BaseOpHook, register_ophooks_recursively
|
||||
|
||||
__all__ = ["BaseOpHook", "register_ophooks_recursively"]
|
@@ -1,32 +0,0 @@
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
|
||||
from . import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ShardGradMemTracerHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param before and after FWD and BWD operator executing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, '_sharded_grad')
|
||||
param._sharded_grad.setup()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
pass
|
@@ -1,48 +0,0 @@
|
||||
import torch
|
||||
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
|
||||
from . import BaseOpHook
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ShardParamHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param before and after FWD and BWD operator executing.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def niter(self):
|
||||
return self._niter
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.gather()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.shard()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.gather()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
for param in module.parameters():
|
||||
assert hasattr(param, 'ca_attr')
|
||||
param.ca_attr.shard()
|
||||
param.data = param.ca_attr.payload()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
pass
|
@@ -1,145 +0,0 @@
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
FORWARD = 0
|
||||
BACKWARD = 1
|
||||
|
||||
|
||||
class GradMemStats():
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.unreleased_grad_flag = {}
|
||||
self.unreleased_grad_volume = 0
|
||||
|
||||
def clear(self):
|
||||
self.unreleased_grad_flag.clear()
|
||||
self.unreleased_grad_volume = 0
|
||||
|
||||
|
||||
class GradMemTracerHook():
|
||||
|
||||
def __init__(self, grad_stats: GradMemStats):
|
||||
self.grad_hook_list = []
|
||||
self._grad_stats = grad_stats
|
||||
|
||||
def grad_handle(self, p, grad):
|
||||
assert self._grad_stats.unreleased_grad_flag[p]
|
||||
free_storage(grad)
|
||||
self._grad_stats.unreleased_grad_volume -= grad.numel() * grad.element_size()
|
||||
self._grad_stats.unreleased_grad_flag[p] = False
|
||||
|
||||
def register_grad_hook(self, module: torch.nn.Module):
|
||||
for p in module.parameters():
|
||||
if p.requires_grad:
|
||||
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
|
||||
self._grad_stats.unreleased_grad_flag[p] = False
|
||||
|
||||
def remove_grad_hook(self):
|
||||
for hook in self.grad_hook_list:
|
||||
hook.remove()
|
||||
|
||||
|
||||
class ParamMemTracerHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
|
||||
super().__init__()
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self._memstats = memstats
|
||||
self._grad_stats = gradstats
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
|
||||
def _free_cuda_params(self, params):
|
||||
for p in params:
|
||||
if p.data.device.type == "cpu":
|
||||
raise NotImplementedError("Only free cuda memory")
|
||||
free_storage(p.data)
|
||||
|
||||
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
|
||||
"""
|
||||
move params to cuda
|
||||
|
||||
Args:
|
||||
params (List[torch.nn.Parameter]): target params
|
||||
|
||||
Raises:
|
||||
NotImplementedError: raise error when param has cpu grad
|
||||
"""
|
||||
for p in params:
|
||||
cur_dev = p.data.device.type
|
||||
if cur_dev == "cpu":
|
||||
if p.grad is not None and p.grad.device.type == "cpu":
|
||||
raise NotImplementedError("Only run in forward propagation")
|
||||
p.data = torch.empty(p.data.shape,
|
||||
device="cuda",
|
||||
dtype=p.data.dtype,
|
||||
requires_grad=p.data.requires_grad)
|
||||
elif cur_dev == "cuda":
|
||||
alloc_storage(p.data)
|
||||
|
||||
def record_model_data_volume(self, params):
|
||||
"""
|
||||
get cuda model data used by params
|
||||
"""
|
||||
data_volume = self._grad_stats.unreleased_grad_volume
|
||||
for p in params:
|
||||
cur_model_data_volume = p.data.numel() * p.data.element_size()
|
||||
data_volume += cur_model_data_volume
|
||||
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
|
||||
# add param.grad, actually param.grad is None in this time
|
||||
data_volume += cur_model_data_volume
|
||||
if not self._grad_stats.unreleased_grad_flag[p]:
|
||||
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
|
||||
self._grad_stats.unreleased_grad_flag[p] = True
|
||||
# record max non model data used for this Op
|
||||
self._memstats.record_max_cuda_model_data(data_volume)
|
||||
|
||||
def pre_op(self, params):
|
||||
max_cuda_used_pre_op = self.mem_monitor.finish()
|
||||
# record max cuda overall data for prev OP.
|
||||
self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
|
||||
# record max cuda non model data for prev OP.
|
||||
self._memstats.calc_max_cuda_non_model_data()
|
||||
|
||||
self._allocate_params_on_cuda(params)
|
||||
# record max cuda model data for current OP
|
||||
self.record_model_data_volume(params)
|
||||
|
||||
self.mem_monitor.start()
|
||||
self._memstats.increase_preop_step(params)
|
||||
|
||||
def post_op(self, params):
|
||||
self._free_cuda_params(params)
|
||||
|
||||
def pre_forward(self, params: List[torch.Tensor]) -> None:
|
||||
self.pre_op(params)
|
||||
|
||||
def post_forward(self, params: List[torch.Tensor]) -> None:
|
||||
self.post_op(params)
|
||||
|
||||
def pre_backward(self, params: List[torch.Tensor]) -> None:
|
||||
self.pre_op(params)
|
||||
|
||||
def post_backward(self, params: List[torch.Tensor]) -> None:
|
||||
self.post_op(params)
|
||||
|
||||
@contextmanager
|
||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||
old_training_phase = self._training_phase
|
||||
try:
|
||||
self._training_phase = training_phase
|
||||
yield
|
||||
finally:
|
||||
self._training_phase = old_training_phase
|
||||
|
||||
switch_to_backward = switch_training_phase
|
||||
switch_to_forward = partial(switch_to_backward, training_phase=TrainingPhase.FORWARD)
|
@@ -1,142 +0,0 @@
|
||||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseOpHook(ABC):
|
||||
"""This class allows users to add customized operations
|
||||
before and after the execution of a PyTorch submodule"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def post_iter(self):
|
||||
pass
|
||||
|
||||
|
||||
# apply torch.autograd.Function that calls a backward_function to tensors in output
|
||||
def _apply_to_tensors_only(module, functional, backward_function, outputs):
|
||||
if type(outputs) is tuple:
|
||||
touched_outputs = []
|
||||
for output in outputs:
|
||||
touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
|
||||
touched_outputs.append(touched_output)
|
||||
return tuple(touched_outputs)
|
||||
elif type(outputs) is torch.Tensor:
|
||||
return functional.apply(module, backward_function, outputs)
|
||||
else:
|
||||
return outputs
|
||||
|
||||
|
||||
class PreBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, outputs):
|
||||
ctx.module = module
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
module.applied_pre_backward = False
|
||||
outputs = outputs.detach()
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
class PostBackwardFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, module, pre_backward_function, output):
|
||||
ctx.module = module
|
||||
output = output.detach()
|
||||
ctx.pre_backward_function = pre_backward_function
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
"""
|
||||
Args:
|
||||
activation_grad of the next layer.
|
||||
Returns:
|
||||
grad of the input activation.
|
||||
"""
|
||||
ctx.pre_backward_function(ctx.module)
|
||||
return (None, None) + args
|
||||
|
||||
|
||||
def register_ophooks_recursively(module: torch.nn.Module,
|
||||
ophook_list: List[BaseOpHook],
|
||||
name: str = "",
|
||||
filter_fn: Optional[Callable] = None):
|
||||
r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert isinstance(ophook_list, (list, tuple))
|
||||
assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
|
||||
for hook in ophook_list:
|
||||
assert (isinstance(hook, BaseOpHook))
|
||||
|
||||
# Add hooks for submodules
|
||||
for child_name, child in module.named_children():
|
||||
register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)
|
||||
|
||||
# Early return on modules with no parameters.
|
||||
if len(list(module.parameters(recurse=False))) == 0:
|
||||
return
|
||||
|
||||
# return from filtered module
|
||||
if filter_fn is not None and filter_fn(module):
|
||||
return
|
||||
|
||||
def _pre_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_fwd_exec(submodule, *args)
|
||||
|
||||
def _post_forward_module_hook(submodule, *args):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_fwd_exec(submodule, *args)
|
||||
|
||||
def _pre_backward_module_hook(submodule, inputs, output):
|
||||
|
||||
def _run_before_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.pre_bwd_exec(submodule, inputs, output)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)
|
||||
|
||||
def _post_backward_module_hook(submodule, inputs):
|
||||
|
||||
def _run_after_backward_function(submodule):
|
||||
for hook in ophook_list:
|
||||
assert isinstance(submodule, torch.nn.Module)
|
||||
hook.post_bwd_exec(submodule, inputs)
|
||||
|
||||
return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)
|
||||
|
||||
module.register_forward_pre_hook(_pre_forward_module_hook)
|
||||
module.register_forward_hook(_post_forward_module_hook)
|
||||
|
||||
module.register_forward_hook(_pre_backward_module_hook)
|
||||
module.register_forward_pre_hook(_post_backward_module_hook)
|
@@ -1,3 +0,0 @@
|
||||
from ._param_hookmgr import BaseParamHookMgr
|
||||
|
||||
__all__ = ["BaseParamHookMgr"]
|
@@ -1,39 +0,0 @@
|
||||
import functools
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseParamHookMgr(object):
|
||||
|
||||
def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
|
||||
r"""
|
||||
register backward hook on every parameters of module
|
||||
"""
|
||||
self._param_list = param_list
|
||||
self._hook_list = []
|
||||
|
||||
def register_backward_hooks(self, hook_call: Callable) -> None:
|
||||
r"""
|
||||
The hook_call will be called every time a gradient with respect to the a param in self.param_list
|
||||
is computed.
|
||||
The hook should have the following signature:
|
||||
```
|
||||
hook(param, grad) -> Tensor or None
|
||||
```
|
||||
"""
|
||||
if not torch.is_grad_enabled():
|
||||
return # don't register grad hooks if grad isn't enabled
|
||||
for p in self._param_list:
|
||||
if p.requires_grad and not hasattr(p, '_base_param_hook'):
|
||||
handle = p.register_hook(functools.partial(hook_call, p))
|
||||
p._base_param_hook = handle
|
||||
|
||||
def remove_hooks(self) -> None:
|
||||
"""
|
||||
Remove hooks from model parameters.
|
||||
"""
|
||||
|
||||
for p in self._param_list:
|
||||
if p.requires_grad and hasattr(p, '_base_param_hook'):
|
||||
p._base_param_hook.remove()
|
@@ -1,209 +0,0 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .gemini_context import GeminiMemoryManager
|
||||
|
||||
|
||||
def sizeof_tensor(tensor: torch.Tensor):
|
||||
return tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
HOLD = 1
|
||||
HOLD_AFTER_FWD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
COMPUTE = 4
|
||||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
Inspired from the paper:
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
# Global Stateful Tensor Manager
|
||||
GST_MGR = GeminiMemoryManager(TensorState)
|
||||
|
||||
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = None
|
||||
self._payload_size = 0 # byte size of current payload
|
||||
|
||||
StatefulTensor.GST_MGR.register_new_instance()
|
||||
|
||||
if self._state == TensorState.FREE:
|
||||
# when the state is free, payload should be None
|
||||
assert maybe_tensor is None, f"payload has to None if state is {self._state}"
|
||||
else:
|
||||
# otherwise, payload should not be None
|
||||
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
|
||||
self._payload = maybe_tensor
|
||||
self._payload_size = sizeof_tensor(maybe_tensor)
|
||||
self.__trans_state_update(TensorState.FREE, state)
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return 0 # if a tensor has no storage, 0 should be returned
|
||||
return self._payload.data_ptr()
|
||||
|
||||
def set_null(self) -> None:
|
||||
# notice that free stateful tensor do not need to become null again
|
||||
if self.state != TensorState.FREE:
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
self.__release()
|
||||
|
||||
def is_null(self) -> bool:
|
||||
if self.state == TensorState.FREE:
|
||||
# check sanity here
|
||||
assert self.payload is None
|
||||
return True
|
||||
return False
|
||||
|
||||
def trans_state(self, state: TensorState) -> None:
|
||||
if self.state == TensorState.FREE:
|
||||
# free stateful tensor can't change state
|
||||
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
|
||||
return
|
||||
|
||||
self.__trans_state_update(self.state, state)
|
||||
|
||||
if state == TensorState.FREE:
|
||||
self.__release()
|
||||
else:
|
||||
self._state = state
|
||||
|
||||
def move_to(self, device: Union[torch.device, int]):
|
||||
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
|
||||
|
||||
if not isinstance(device, torch.device):
|
||||
to_device = torch.device('cuda', device)
|
||||
else:
|
||||
to_device = device
|
||||
|
||||
from_device_type = self.device.type
|
||||
if from_device_type == to_device.type:
|
||||
# from device == to device
|
||||
return
|
||||
|
||||
# update manager's information
|
||||
self.__trans_device_update(from_device_type, to_device.type)
|
||||
self.payload.data = self.payload.data.to(to_device)
|
||||
|
||||
def payload_copy(self, tensor) -> None:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def payload_reset(self, tensor) -> None:
|
||||
|
||||
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
|
||||
|
||||
if self.payload is not None:
|
||||
# release old payload
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
else:
|
||||
# otherwise, set the state to HOLD for new payload
|
||||
self._state = TensorState.HOLD
|
||||
del self._payload
|
||||
|
||||
self._payload = tensor
|
||||
self._payload_size = sizeof_tensor(tensor)
|
||||
# record new payload
|
||||
self.__trans_state_update(TensorState.FREE, self.state)
|
||||
|
||||
def payload_relay(self, rhs):
|
||||
# relay the payload of rhs to current stateful tensor
|
||||
# can't support null relay right now
|
||||
assert not rhs.is_null()
|
||||
|
||||
# now this function only support stateful tensor that has zero-length payload
|
||||
# because it doesn't require memory manager updating
|
||||
# you can extend this function by yourself
|
||||
assert self.payload_size == 0
|
||||
|
||||
self._payload = rhs.payload
|
||||
self._payload_size = rhs.payload_size
|
||||
self._state = TensorState.HOLD
|
||||
self.__trans_state_update(rhs.state, TensorState.HOLD)
|
||||
|
||||
rhs.__release()
|
||||
|
||||
@property
|
||||
def payload(self) -> Optional[torch.Tensor]:
|
||||
return self._payload
|
||||
|
||||
@property
|
||||
def payload_size(self) -> int:
|
||||
return self._payload_size
|
||||
|
||||
@property
|
||||
def state(self) -> TensorState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
|
||||
|
||||
def __release(self):
|
||||
# release current payload
|
||||
# shouldn't be visible to users
|
||||
self._state = TensorState.FREE
|
||||
self._payload = None
|
||||
self._payload_size = 0
|
||||
|
||||
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
|
||||
"""Update global manager when changing the state of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
device_type = self.device.type
|
||||
|
||||
if from_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][from_state] -= size
|
||||
else:
|
||||
# when from_state is FREE, the tensor is new to manager
|
||||
# we should add its memory
|
||||
manager.total_mem[device_type] += size
|
||||
|
||||
if to_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][to_state] += size
|
||||
else:
|
||||
# when to_state is FREE, the tensor will be deleted soon
|
||||
# we should sub its memory
|
||||
manager.total_mem[device_type] -= size
|
||||
|
||||
def __trans_device_update(self, from_type: str, to_type: str):
|
||||
"""Update global manager when changing the device of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
state = self.state
|
||||
|
||||
# update aggregated information
|
||||
manager.total_mem[from_type] -= size
|
||||
manager.total_mem[to_type] += size
|
||||
|
||||
# update the information of each state
|
||||
manager.state_mem[from_type][state] -= size
|
||||
manager.state_mem[to_type][state] += size
|
||||
|
||||
def __del__(self):
|
||||
self.set_null()
|
||||
StatefulTensor.GST_MGR.delete_instance()
|
||||
del self
|
@@ -1,103 +0,0 @@
|
||||
import functools
|
||||
import types
|
||||
from time import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from .stateful_tensor import StatefulTensor, TensorState
|
||||
from .tensor_placement_policy import TensorPlacementPolicy
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class StatefulTensorMgr(object):
|
||||
"""
|
||||
Stateful Tensor Manager, inspired from PatrickStar
|
||||
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
|
||||
self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
|
||||
self._stateful_tensor_list: List[StatefulTensor] = []
|
||||
|
||||
self._compute_list: List[StatefulTensor] = []
|
||||
self._compute_idx: int = -1
|
||||
|
||||
self._cpu_gpu_move_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
self._warmup = True
|
||||
|
||||
def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
|
||||
assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
|
||||
self._stateful_tensor_list = tensor_list
|
||||
for t in self._stateful_tensor_list:
|
||||
assert isinstance(t, StatefulTensor)
|
||||
t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
|
||||
|
||||
def start_iter(self):
|
||||
pass
|
||||
|
||||
def finish_iter(self):
|
||||
"""This function must be called when each iteration finishes
|
||||
"""
|
||||
self._warmup = False
|
||||
self._compute_idx = -1
|
||||
self._cpu_gpu_move_volume = 0
|
||||
self._layout_time = 0
|
||||
self._evict_time = 0
|
||||
|
||||
def adjust_layout(self) -> None:
|
||||
""" Adjust the layout of stateful tensor according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
|
||||
start = time()
|
||||
move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
|
||||
self._layout_time += time() - start
|
||||
vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
|
||||
cuda_demand=cuda_demand,
|
||||
warmup=self._warmup,
|
||||
compute_list=self._compute_list,
|
||||
compute_idx=self._compute_idx)
|
||||
self._cpu_gpu_move_volume += vol
|
||||
self._evict_time += evict_time
|
||||
# move COMPUTE tensors to CUDA
|
||||
self._cpu_gpu_move_volume += cuda_demand
|
||||
for t in move_to_cuda_tensor_list:
|
||||
colo_model_data_tensor_move_inline(t, get_current_device())
|
||||
|
||||
@property
|
||||
def cpu_gpu_move_volume(self):
|
||||
return self._cpu_gpu_move_volume
|
||||
|
||||
def _trans_state(self, trans_state_func, stateful_tensor, state):
|
||||
trans_state_func(state)
|
||||
if state == TensorState.COMPUTE:
|
||||
self._compute_idx += 1
|
||||
if self._warmup:
|
||||
self._compute_list.append(stateful_tensor)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool):
|
||||
move_to_cuda_tensor_list = []
|
||||
hold_cuda_tensor_list = []
|
||||
for tensor in self._stateful_tensor_list:
|
||||
if tensor.state == TensorState.FREE:
|
||||
continue
|
||||
|
||||
if tensor.device.type == 'cuda':
|
||||
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
|
||||
hold_cuda_tensor_list.append(tensor)
|
||||
elif tensor.device.type == 'cpu':
|
||||
if tensor.state == TensorState.COMPUTE:
|
||||
move_to_cuda_tensor_list.append(tensor)
|
||||
else:
|
||||
raise RuntimeError
|
||||
return move_to_cuda_tensor_list, hold_cuda_tensor_list
|
@@ -1,139 +0,0 @@
|
||||
import functools
|
||||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
|
||||
|
||||
class TensorPlacementPolicy(ABC):
|
||||
|
||||
def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
self.device: Optional[torch.device] = device
|
||||
self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
|
||||
|
||||
@abstractmethod
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CPUTensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
|
||||
volume = 0
|
||||
for t in hold_cuda_tensor_list:
|
||||
colo_model_data_tensor_move_inline(t, self.device)
|
||||
volume += t.payload.numel() * t.payload.element_size()
|
||||
return volume, 0
|
||||
|
||||
|
||||
class CUDATensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
|
||||
super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
|
||||
|
||||
def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> int:
|
||||
return 0, 0
|
||||
|
||||
|
||||
class AutoTensorPlacementPolicy(TensorPlacementPolicy):
|
||||
|
||||
def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
|
||||
super().__init__(None, mem_stats_collector=mem_stats_collector)
|
||||
# model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# TODO(ver217): make these args configurable
|
||||
self._warmup_non_model_data_ratio: float = 0.8
|
||||
self._steady_cuda_cap_ratio: float = 0.9
|
||||
|
||||
def evict_tensors(self,
|
||||
hold_cuda_tensor_list: List[StatefulTensor],
|
||||
cuda_demand: int = 0,
|
||||
warmup: bool = True,
|
||||
compute_list: List[StatefulTensor] = [],
|
||||
compute_idx: int = 0,
|
||||
**kwargs) -> int:
|
||||
"""
|
||||
Evict tensors from CUDA device.
|
||||
|
||||
Args:
|
||||
hold_cuda_tensor_list (List[StatefulTensor]): the list of tensor in state of HOLD-like
|
||||
cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
|
||||
warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
|
||||
compute_list (List[StatefulTensor], optional): TODO. Defaults to [].
|
||||
compute_idx (int, optional): the idx of computing device. Defaults to 0.
|
||||
|
||||
Raises:
|
||||
RuntimeError:
|
||||
|
||||
Returns:
|
||||
int: the volume of memory that is evicted
|
||||
"""
|
||||
start = time()
|
||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||
used_cuda_model_data = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
if warmup:
|
||||
# We designate a part of CUDA memory for model data in warmup iterations.
|
||||
max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
|
||||
else:
|
||||
# max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
|
||||
max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
|
||||
cuda_capacity *= self._steady_cuda_cap_ratio
|
||||
total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
|
||||
avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
|
||||
freed_cuda_model_data = 0
|
||||
end = time()
|
||||
if avail_cuda_model_data < cuda_demand:
|
||||
# Move cuda_demand - avail_cuda_model_data volume of tensors
|
||||
# to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
|
||||
to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
|
||||
to_free_tensor_list = hold_cuda_tensor_list
|
||||
if not warmup:
|
||||
to_free_tensor_list = self._sort_hold_cuda_tensors(tuple(hold_cuda_tensor_list), compute_idx,
|
||||
tuple(compute_list))
|
||||
# print(self._sort_hold_cuda_tensors.cache_info())
|
||||
end = time()
|
||||
for t in to_free_tensor_list:
|
||||
if freed_cuda_model_data >= to_free_cuda_model_data:
|
||||
break
|
||||
freed_cuda_model_data += t.payload_size
|
||||
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
|
||||
if freed_cuda_model_data < to_free_cuda_model_data:
|
||||
raise RuntimeError(
|
||||
f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
|
||||
)
|
||||
return freed_cuda_model_data, end - start
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _sort_hold_cuda_tensors(hold_cuda_tensors: tuple, compute_idx: int, compute_list: tuple) -> list:
|
||||
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensors}
|
||||
for i in range(len(compute_list) - 1, compute_idx, -1):
|
||||
if compute_list[i] in next_compute_idx:
|
||||
next_compute_idx[compute_list[i]] = i
|
||||
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
|
||||
return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
|
||||
class TensorPlacementPolicyFactory:
|
||||
|
||||
@staticmethod
|
||||
def create(policy_name: str) -> Type[TensorPlacementPolicy]:
|
||||
if policy_name == 'cpu':
|
||||
return CPUTensorPlacementPolicy
|
||||
elif policy_name == 'cuda':
|
||||
return CUDATensorPlacementPolicy
|
||||
elif policy_name == 'auto':
|
||||
return AutoTensorPlacementPolicy
|
||||
else:
|
||||
raise TypeError(f"Unknown tensor placement policy {policy_name}")
|
@@ -1,120 +0,0 @@
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||
return tensor.storage().size() == 0
|
||||
|
||||
|
||||
def free_storage(tensor: torch.Tensor) -> None:
|
||||
if not is_storage_empty(tensor):
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
|
||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||
if is_storage_empty(tensor):
|
||||
tensor.storage().resize_(tensor.numel())
|
||||
|
||||
|
||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
t = tensor.payload
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
t = tensor
|
||||
else:
|
||||
return 0, 0
|
||||
|
||||
cuda_use, cpu_use = 0, 0
|
||||
|
||||
mem_use = t.storage().size() * t.element_size()
|
||||
if t.device.type == 'cuda':
|
||||
cuda_use += mem_use
|
||||
elif t.device.type == 'cpu':
|
||||
cpu_use += mem_use
|
||||
|
||||
return cuda_use, cpu_use
|
||||
|
||||
|
||||
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
|
||||
torch.Tensor]) -> None:
|
||||
"""
|
||||
A colossal API for model data tensor move.
|
||||
The src and target tensors could be resident on both CPU and GPU.
|
||||
|
||||
NOTE() The source tensor payload will be removed after this function.
|
||||
|
||||
The function will record the communication volume between CPU and GPU.
|
||||
Args:
|
||||
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
|
||||
"""
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t_payload = src_t.payload
|
||||
else:
|
||||
src_t_payload = src_t.data
|
||||
src_dev = src_t_payload.device
|
||||
|
||||
if isinstance(tgt_t, StatefulTensor):
|
||||
tgt_t_payload = tgt_t.payload
|
||||
else:
|
||||
tgt_t_payload = tgt_t.data
|
||||
|
||||
tgt_t_payload.copy_(src_t_payload)
|
||||
|
||||
# remove payload of src_t
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t.set_null()
|
||||
else:
|
||||
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
|
||||
|
||||
|
||||
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
|
||||
int]) -> None:
|
||||
"""
|
||||
move a tensor to the target_device
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
|
||||
target_device: a target device, if type is int, it the index of cuda card.
|
||||
"""
|
||||
if not isinstance(target_device, torch.device):
|
||||
target_device = torch.device(f'cuda:{target_device}')
|
||||
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.to(target_device)
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(target_device)
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
|
||||
"""colo_model_data_move_to_cpu
|
||||
move a model data tensor from gpu to cpu
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): _description_
|
||||
"""
|
||||
# TODO() optimize the tensor moving with non-blocking
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.cpu()
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(torch.device('cpu'))
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Clone a model data tensor
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
|
||||
target_device (torch.device): the target device
|
||||
Returns:
|
||||
torch.Tensor: a cloned torch tensor
|
||||
"""
|
||||
# TODO() rename this function
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t_payload = t.payload if isinstance(t, StatefulTensor) else t
|
||||
return t_payload
|
@@ -1,3 +0,0 @@
|
||||
from .init_context import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
|
||||
|
||||
__all__ = ['ZeroInitContext', 'no_shard_zero_context', 'no_shard_zero_decrator']
|
@@ -1,270 +0,0 @@
|
||||
import contextlib
|
||||
import functools
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.context.singleton_meta import SingletonMeta
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
|
||||
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_param import ShardedParamV2
|
||||
|
||||
|
||||
@dataclass
|
||||
class ZeroContextConfig:
|
||||
"""The configuration used to control zero context initialization.
|
||||
|
||||
Args:
|
||||
target_device (torch.device): The device where param data are after exiting the context.
|
||||
is_replicated (bool, optional): Whether the param is replicated across data parallel group.
|
||||
Some parameters are not replicated, e.g. parameters in MOE experts.
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
"""
|
||||
|
||||
target_device: torch.device
|
||||
is_replicated: bool = True
|
||||
shard_param: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.shard_param:
|
||||
assert self.is_replicated, "Non-replicated parameters can't be sharded."
|
||||
|
||||
if self.is_replicated and not self.shard_param:
|
||||
assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located in cuda."
|
||||
|
||||
|
||||
class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||
"""A context to initialize model.
|
||||
|
||||
1. Convert the model to fp16.
|
||||
2. The parameters of the module are adapted to type ShardedParameter.
|
||||
3. Shard the param and grad according to flags.
|
||||
|
||||
Args:
|
||||
target_device (torch.device): The device where param data are after exiting the context.
|
||||
shard_strategy (BaseShardStrategy): Shard strategy instance.
|
||||
seed (int, optional): Random seed for weight initialization
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
|
||||
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
|
||||
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
target_device: torch.device,
|
||||
shard_strategy: BaseShardStrategy,
|
||||
seed: int = 2**10 - 1,
|
||||
shard_param: bool = False,
|
||||
default_dtype: Optional[torch.dtype] = None,
|
||||
bf16: bool = False,
|
||||
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
|
||||
|
||||
super().__init__(default_dtype=default_dtype)
|
||||
self.shard_strategy = shard_strategy
|
||||
self.param_list = []
|
||||
self.model_numel_tensor = model_numel_tensor
|
||||
self.seed = seed
|
||||
self.bf16 = bf16
|
||||
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
|
||||
|
||||
ZeroContextMgr().current_context = self
|
||||
|
||||
self.param_numel = {}
|
||||
self.top_module = None
|
||||
|
||||
@property
|
||||
def target_device(self):
|
||||
return self.config.target_device
|
||||
|
||||
@property
|
||||
def is_replicated(self):
|
||||
return self.config.is_replicated
|
||||
|
||||
@property
|
||||
def shard_param(self):
|
||||
return self.config.shard_param
|
||||
|
||||
@staticmethod
|
||||
def calc_fanin_fanout(tensor: torch.Tensor):
|
||||
"""We use this function to substitute fan-in and fan-out calculation in torch.nn.init.
|
||||
This can help us get correct fan-in and fan-out for sharded tensor.
|
||||
"""
|
||||
assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"
|
||||
|
||||
# get correct shape of input tensor
|
||||
if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded:
|
||||
tensor_shape = tensor.shape
|
||||
else:
|
||||
tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape
|
||||
|
||||
dimensions = len(tensor_shape)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
|
||||
num_input_fmaps = tensor_shape[1]
|
||||
num_output_fmaps = tensor_shape[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
# math.prod is not always available, accumulate the product manually
|
||||
# we could use functools.reduce but that is not supported by TorchScript
|
||||
for s in tensor_shape[2:]:
|
||||
receptive_field_size *= s
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
def _pre_context_exec(self):
|
||||
"""
|
||||
The Callback function when entering the context
|
||||
"""
|
||||
self.logger = get_dist_logger("ZeroInitContext")
|
||||
|
||||
# substitute fan-in and fan-out calculation
|
||||
self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
|
||||
nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout
|
||||
|
||||
self.module_load_from_state_dict = nn.Module._load_from_state_dict
|
||||
shard_strategy = self.shard_strategy if self.config.shard_param else None
|
||||
nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
|
||||
shard_strategy=shard_strategy)
|
||||
self.module_state_dict = nn.Module.state_dict
|
||||
nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
|
||||
shard_strategy=shard_strategy,
|
||||
state_dict_func=self.module_state_dict,
|
||||
process_group=self.dp_process_group)
|
||||
|
||||
# reserve rng states
|
||||
self.cpu_rng_state = torch.get_rng_state()
|
||||
self.cuda_rng_state = torch.cuda.get_rng_state()
|
||||
|
||||
# set new seed for initialization, since we initialize sharded tensor separately
|
||||
# we don't want all processes have the same seed
|
||||
# otherwise all sharded tensors are same after init
|
||||
offset = self.seed + 1 # we want to have more 1 in binary format seed
|
||||
torch.manual_seed(self.seed + offset * dist.get_rank())
|
||||
|
||||
def _post_context_exec(self):
|
||||
"""The callback function when exiting context.
|
||||
"""
|
||||
# broadcast replicated no-shard parameters
|
||||
src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
|
||||
for param in self.param_list:
|
||||
assert hasattr(param, 'colo_attr')
|
||||
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
|
||||
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
del self.param_list
|
||||
|
||||
nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
|
||||
nn.Module.load_state_dict = self.module_load_from_state_dict
|
||||
nn.Module.state_dict = self.module_state_dict
|
||||
torch.set_rng_state(self.cpu_rng_state)
|
||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||
|
||||
params = frozenset(self.top_module.parameters())
|
||||
for param in self.param_numel.keys():
|
||||
if param not in params:
|
||||
self.param_numel[param] = 0
|
||||
self.model_numel_tensor.fill_(sum(self.param_numel.values()))
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
The function to call at the end of the constructor of each module.
|
||||
NOTE() The module may be passed to this function multiple times.
|
||||
"""
|
||||
self.top_module = module
|
||||
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
|
||||
|
||||
def half_fn(t: torch.Tensor):
|
||||
return t.to(half_dtype) if t.is_floating_point() else t
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
# avoid adapting a param to ShardedParam twice
|
||||
if hasattr(param, 'colo_attr'):
|
||||
continue
|
||||
|
||||
self.param_numel[param] = param.numel()
|
||||
|
||||
# convert parameters to half
|
||||
param_half = half_fn(param)
|
||||
param.data = param_half
|
||||
if param.grad is not None:
|
||||
grad_half = half_fn(param.grad)
|
||||
param.grad.data = grad_half
|
||||
|
||||
# move torch parameters to the target device
|
||||
target_device = self.target_device
|
||||
param.data = param.data.to(target_device)
|
||||
if param.grad is not None:
|
||||
param.grad = param.grad.to(target_device)
|
||||
|
||||
param.colo_attr = ShardedParamV2(param, set_data_none=True)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
|
||||
param.data = param.colo_attr.data_payload # set param.data to payload
|
||||
|
||||
# mark whether the param is replicated
|
||||
param.colo_attr.is_replicated = self.is_replicated
|
||||
|
||||
# mark whether the param should keep not sharded
|
||||
# if True, the param is used as Zero stage 2
|
||||
param.colo_attr.keep_not_shard = not self.shard_param
|
||||
|
||||
self.param_list.append(param)
|
||||
|
||||
# We must cast buffers
|
||||
# If we use BN, buffers may be on CPU and Float
|
||||
# We must cast them
|
||||
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
|
||||
for buffer in module.buffers(recurse=False):
|
||||
buffer.data = buffer.data.to(device=torch.cuda.current_device())
|
||||
buffer.data = cast_fn(buffer.data)
|
||||
|
||||
|
||||
class ZeroContextMgr(metaclass=SingletonMeta):
|
||||
current_context: Optional[ZeroInitContext] = None
|
||||
|
||||
@contextlib.contextmanager
|
||||
def hijack_context_config(self, **kwargs):
|
||||
if self.current_context is None:
|
||||
yield
|
||||
else:
|
||||
old_config = self.current_context.config
|
||||
self.current_context.config = ZeroContextConfig(**kwargs)
|
||||
yield
|
||||
self.current_context.config = old_config
|
||||
|
||||
|
||||
def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
|
||||
return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
|
||||
is_replicated=is_replicated,
|
||||
shard_param=False)
|
||||
|
||||
|
||||
def no_shard_zero_decrator(is_replicated: bool = True):
|
||||
|
||||
def _wrapper(init_func):
|
||||
|
||||
def _no_shard(*args, **kwargs):
|
||||
with no_shard_zero_context(is_replicated):
|
||||
ret = init_func(*args, **kwargs)
|
||||
return ret
|
||||
|
||||
return _no_shard
|
||||
|
||||
return _wrapper
|
@@ -1,5 +0,0 @@
|
||||
from .base_shard_strategy import BaseShardStrategy
|
||||
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
|
@@ -1,22 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class BaseShardStrategy(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Abstract Shard Strategy. Use to shard a tensors on multiple GPUs.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
pass
|
@@ -1,47 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch._utils import _flatten_dense_tensors as flatten
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
from .tensor_shard_strategy import TensorShardStrategy
|
||||
|
||||
|
||||
class BucketTensorShardStrategy(TensorShardStrategy):
|
||||
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
|
||||
which will fully utilize network bandwidth.
|
||||
It is especially useful when sub-module contains bias,
|
||||
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small).
|
||||
"""
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
|
||||
tensor_list: List[ShardedTensor] = [t for t in tensor_list if t.is_sharded]
|
||||
if len(tensor_list) == 0:
|
||||
return
|
||||
target_device = tensor_list[0].device
|
||||
dtype = tensor_list[0].dtype
|
||||
buffer_list: List[torch.Tensor] = []
|
||||
tensor_numels = [t.payload.numel() for t in tensor_list]
|
||||
buffer_size = sum(tensor_numels)
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||
else:
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||
# Move to target device before splitting buffer
|
||||
# Ensure we utilize maximum PCIE bandwidth
|
||||
buffer_list = [buffer.to(target_device) for buffer in buffer_list]
|
||||
offset = 0
|
||||
for i, t in enumerate(tensor_list):
|
||||
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
|
||||
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
|
||||
t.payload_reset(gathered_payload)
|
||||
t.is_sharded = False
|
||||
offset += tensor_numels[i]
|
@@ -1,22 +0,0 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
|
||||
"""Return the local shard of a full tensor."""
|
||||
# Shard using torch.chunk to match all-gather/reduce-scatter.
|
||||
chunks = list(torch.flatten(tensor).chunk(world_size))
|
||||
while len(chunks) < world_size:
|
||||
chunks.append(chunks[0].new_empty(0))
|
||||
|
||||
# Determine number of padding elements.
|
||||
num_to_pad = chunks[0].numel() - chunks[rank].numel()
|
||||
assert num_to_pad >= 0, num_to_pad
|
||||
|
||||
shard = torch.zeros_like(chunks[0])
|
||||
length = chunks[rank].size(0)
|
||||
shard_temp = shard[:length]
|
||||
shard_temp.copy_(chunks[rank])
|
||||
|
||||
return shard, num_to_pad
|
@@ -1,59 +0,0 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.shard_utils.commons import get_shard
|
||||
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
"""
|
||||
A naive implementation which shard each tensor evenly over all ranks
|
||||
"""
|
||||
|
||||
def shard(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._shard_tensor(t, process_group)
|
||||
|
||||
def gather(self, tensor_list: List[ShardedTensor], process_group: Optional[dist.ProcessGroup] = None):
|
||||
for t in tensor_list:
|
||||
self._gather_tensor(t, process_group)
|
||||
|
||||
def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
""" Shard tensor among processes.
|
||||
|
||||
Args:
|
||||
t (ShardedTensor): a tensor to be sharded.
|
||||
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
|
||||
Defaults to None.
|
||||
"""
|
||||
if t.is_sharded:
|
||||
return
|
||||
if t.payload.device.type == 'cuda':
|
||||
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
|
||||
f" but current cuda device is {get_current_device()}"
|
||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||
t.payload_reset(sharded_payload)
|
||||
t.is_sharded = True
|
||||
|
||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
if not t.is_sharded:
|
||||
return
|
||||
target_device = t.device
|
||||
payload_numel = t.payload.numel()
|
||||
world_size = dist.get_world_size(process_group)
|
||||
rank = dist.get_rank(process_group)
|
||||
|
||||
buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device())
|
||||
buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0))
|
||||
buffer_list[rank].copy_(t.payload)
|
||||
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
||||
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||
t.payload_reset(gathered_payload)
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t.is_sharded = False
|
@@ -1,3 +0,0 @@
|
||||
from .sharded_model_v2 import ShardedModelV2
|
||||
|
||||
__all__ = ['ShardedModelV2']
|
@@ -1,85 +0,0 @@
|
||||
from typing import Any, Callable, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
||||
factor: int = 1
|
||||
while world_size % factor == 0 and world_size / factor > factor:
|
||||
factor *= 2
|
||||
return float(factor)
|
||||
|
||||
|
||||
def free_storage(data: torch.Tensor) -> None:
|
||||
"""Free underlying storage of a Tensor."""
|
||||
if data.storage().size() > 0:
|
||||
# Since we're modifying the Tensor's Storage directly, make sure the Tensor
|
||||
# is the sole occupant of the Storage.
|
||||
assert data.storage_offset() == 0
|
||||
data.storage().resize_(0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
|
||||
"""Allocate storage for a tensor."""
|
||||
if data.storage().size() == size.numel(): # no need to reallocate
|
||||
return
|
||||
assert data.storage().size() == 0
|
||||
data.storage().resize_(size.numel())
|
||||
|
||||
|
||||
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
|
||||
return tensor.half()
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
|
||||
if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
|
||||
return tensor.float()
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
|
||||
return tensor.bfloat16()
|
||||
return tensor
|
||||
|
||||
|
||||
def apply_to_tensors(x: Any, fn: Callable):
|
||||
if torch.is_tensor(x):
|
||||
return fn(x)
|
||||
elif isinstance(x, list):
|
||||
return [apply_to_tensors(t, fn) for t in x]
|
||||
elif isinstance(x, tuple):
|
||||
return tuple(apply_to_tensors(t, fn) for t in x)
|
||||
elif isinstance(x, dict):
|
||||
return {key: apply_to_tensors(val, fn) for key, val in x.items()}
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def cast_float_arguments(fn: Callable, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
|
||||
return apply_to_tensors(args, fn), apply_to_tensors(kwargs, fn)
|
||||
|
||||
|
||||
def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
|
||||
"""Chunk a given Tensor into num_chunks parts and add any necessary padding."""
|
||||
chunks = list(torch.flatten(tensor).chunk(num_chunks))
|
||||
# torch.chunk may return fewer than num_chunks chunks, pad accordingly.
|
||||
num_pad_for_partial_chunk = chunks[0].numel() - chunks[-1].numel()
|
||||
if num_pad_for_partial_chunk > 0:
|
||||
chunks[-1] = F.pad(chunks[-1], [0, num_pad_for_partial_chunk])
|
||||
if len(chunks) < num_chunks:
|
||||
chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
|
||||
return chunks
|
@@ -1,200 +0,0 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
# TODO: Remove the toggle-enable_nccl_base_collectives when github open issue #801 is resolved.
|
||||
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
|
||||
enable_nccl_base_collectives = False
|
||||
else:
|
||||
enable_nccl_base_collectives = True
|
||||
|
||||
|
||||
class Bucket:
|
||||
|
||||
def __init__(self, shard_size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
|
||||
self.buffer = torch.zeros((group.size(), shard_size), dtype=dtype, device=device)
|
||||
self.group = group
|
||||
self.offset = 0
|
||||
self.callbacks: List[Callable] = []
|
||||
self.output_shard = torch.zeros_like(self.buffer[0])
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush content of the bucket."""
|
||||
if self.offset == 0:
|
||||
assert len(self.callbacks) == 0
|
||||
return
|
||||
# reduce-scatter bucket
|
||||
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
|
||||
dist._reduce_scatter_base(self.output_shard[:self.offset],
|
||||
self.buffer[:, :self.offset].contiguous(),
|
||||
group=self.group)
|
||||
else:
|
||||
dist.reduce_scatter(self.output_shard[:self.offset],
|
||||
list(self.buffer[:, :self.offset].unbind(0)),
|
||||
group=self.group)
|
||||
# execute post-reduction callbacks
|
||||
for callback_fn in self.callbacks:
|
||||
callback_fn()
|
||||
# reuse input bucket but allocate a fresh output shard
|
||||
self.buffer[:, :self.offset].zero_()
|
||||
self.offset = 0
|
||||
self.callbacks.clear()
|
||||
self.output_shard = torch.zeros_like(self.buffer[0])
|
||||
|
||||
def alloc(self) -> None:
|
||||
"""Setup the buffers if they are not allocated.
|
||||
|
||||
Using ``setup`` and ``teardown``, we can ensure that the bucket
|
||||
buffers are only allocated during the backward pass, hence saving more
|
||||
memory to other parts of the training process, such as the forward pass
|
||||
for activation memory.
|
||||
"""
|
||||
for tensor in [self.buffer, self.output_shard]:
|
||||
if tensor.storage().size() == 0:
|
||||
tensor.storage().resize_(tensor.size().numel())
|
||||
|
||||
def free(self) -> None:
|
||||
"""Tear down the bucket by freeing the memory"""
|
||||
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
|
||||
for tensor in [self.buffer, self.output_shard]:
|
||||
tensor.storage().resize_(0)
|
||||
|
||||
def append(self, tensor_list: List[Tensor], callback_fn: Callable):
|
||||
# copy data from input_list into bucket
|
||||
tensor_size = tensor_list[0].numel()
|
||||
stacked_input = torch.stack(tensor_list).view(self.group.size(), tensor_size)
|
||||
offset = self.offset
|
||||
self.buffer[:, offset:offset + tensor_size].copy_(stacked_input)
|
||||
self.offset += tensor_size
|
||||
|
||||
# callback will be given the reduced result
|
||||
if callback_fn is not None:
|
||||
result_view = self.output_shard[offset:offset + tensor_size].view_as(tensor_list[0])
|
||||
self.callbacks.append(functools.partial(callback_fn, result_view))
|
||||
|
||||
|
||||
class ReduceScatterBucketer:
|
||||
"""
|
||||
Helper for bucketing multiple reduce-scatter operations on small tensors
|
||||
into larger reduce-scatter ops to improve communication efficiency.
|
||||
|
||||
Usage::
|
||||
|
||||
bucketer = ReduceScatterBucketer()
|
||||
bucketer.reduce_scatter_async(
|
||||
small_tensors, callback_fn=lambda result: print("small")
|
||||
)
|
||||
bucketer.reduce_scatter_async(
|
||||
big_tensors, callback_fn=lambda result: print("big")
|
||||
)
|
||||
bucketer.reduce_scatter_async(
|
||||
more_small_tensors, callback_fn=lambda result: print("small2")
|
||||
)
|
||||
bucketer.flush() # callbacks only guaranteed to be called after flush()
|
||||
# Example output (note that it is out of order, due to bucketing):
|
||||
# big
|
||||
# small
|
||||
# small2
|
||||
|
||||
Args:
|
||||
bucket_size_mb (int, Optional): bucket size for communicating. Buckets
|
||||
are sub-divided based on world_size. Values <= 0 disable bucketing.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket_size_mb: int = 25):
|
||||
self.bucket_size_mb = bucket_size_mb
|
||||
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
|
||||
|
||||
@torch.no_grad()
|
||||
def reduce_scatter_async(
|
||||
self,
|
||||
input_list: List[Tensor],
|
||||
group: ProcessGroup,
|
||||
callback_fn: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Reduce-scatter a list of tensors asynchronously, so smaller reductions
|
||||
can be bucketed together. The given callback (``callback_fn``) will be
|
||||
called with the reduced result at some later time. Call ``flush()`` to
|
||||
force all queued ops and callbacks to be executed.
|
||||
|
||||
Note that large inputs will be reduced immediately, and this function
|
||||
may also flush the relevant bucket to make room for ``input_list``.
|
||||
|
||||
Args:
|
||||
input_list (List[Tensor]): list of tensors to reduce-scatter. List
|
||||
should contain ``group.size()`` tensors and each tensor should
|
||||
have identical shape, dtype and device.
|
||||
group (ProcessGroup): process group for reduction
|
||||
callback_fn (Callable, Optional): callback function to call after
|
||||
the reduction executes. Function will be called with a single
|
||||
argument corresponding to the reduced result.
|
||||
"""
|
||||
world_size = group.size()
|
||||
|
||||
assert (len(input_list) == world_size
|
||||
), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"
|
||||
|
||||
first_input = input_list[0]
|
||||
first_input_size = first_input.numel()
|
||||
|
||||
bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
|
||||
if first_input_size > bucket_shard_size:
|
||||
# TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
|
||||
# input is too big to fit in the bucket, reduce-scatter directly
|
||||
output = torch.zeros_like(input_list[0])
|
||||
if hasattr(dist, "_reduce_scatter_base") and enable_nccl_base_collectives:
|
||||
input_flattened = torch.cat(input_list)
|
||||
dist._reduce_scatter_base(output, input_flattened, group=group)
|
||||
else:
|
||||
# fallback
|
||||
dist.reduce_scatter(output, input_list, group=group)
|
||||
if callback_fn is not None:
|
||||
callback_fn(output)
|
||||
return
|
||||
|
||||
bucket = self._get_bucket(first_input, group)
|
||||
if first_input_size > bucket.buffer.size(1) - bucket.offset:
|
||||
# not enough space remaining in bucket, flush it now
|
||||
bucket.flush()
|
||||
bucket.append(input_list, callback_fn)
|
||||
|
||||
@torch.no_grad()
|
||||
def flush(self) -> None:
|
||||
"""Reduce-scatter any partial buckets."""
|
||||
for bucket in self.buckets.values():
|
||||
bucket.flush()
|
||||
|
||||
@torch.no_grad()
|
||||
def free(self) -> None:
|
||||
"""Free buffers from all buckets."""
|
||||
for bucket in self.buckets.values():
|
||||
bucket.free()
|
||||
|
||||
@functools.lru_cache()
|
||||
def _get_shard_size(self, element_size: int, num_shards: int) -> int:
|
||||
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
|
||||
return 0
|
||||
MB = 1024 * 1024
|
||||
bucket_size = self.bucket_size_mb * MB / element_size
|
||||
return int(bucket_size // num_shards)
|
||||
|
||||
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
|
||||
key = (tensor.dtype, tensor.device, group)
|
||||
if key not in self.buckets:
|
||||
# buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
|
||||
world_size = group.size()
|
||||
shard_size = self._get_shard_size(tensor.element_size(), world_size)
|
||||
self.buckets[key] = Bucket(shard_size, tensor.dtype, tensor.device, group)
|
||||
self.buckets[key].alloc()
|
||||
return self.buckets[key]
|
@@ -1,577 +0,0 @@
|
||||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
import functools
|
||||
import itertools
|
||||
from collections import OrderedDict
|
||||
from copy import deepcopy
|
||||
from typing import Any, Iterator, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import disposable, get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
|
||||
from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
|
||||
from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
|
||||
from ._utils import (
|
||||
cast_float_arguments,
|
||||
cast_tensor_to_bf16,
|
||||
cast_tensor_to_fp16,
|
||||
cast_tensor_to_fp32,
|
||||
chunk_and_pad,
|
||||
free_storage,
|
||||
get_gradient_predivide_factor,
|
||||
)
|
||||
from .zero_hook import ZeroHook
|
||||
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
|
||||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
"""
|
||||
A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
|
||||
Only `1/#nproc` of parameters, gradients are stored in local CUDA memory, so forward and backward
|
||||
passes can be executed with limited CUDA memory budget.
|
||||
|
||||
Note:
|
||||
You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.
|
||||
Note:
|
||||
Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
|
||||
if you enable ``reuse_fp16_shard``.
|
||||
|
||||
Args:
|
||||
module (nn.Module): A sharded module, which must be initialized by `ZeroInitContext`.
|
||||
shard_strategy (BaseShardStrategy): A shard strategy to manage shard behavior.
|
||||
process_group (Optional[ProcessGroup], optional): Data parallel process group. Defaults to None.
|
||||
reduce_scatter_process_group (Optional[ProcessGroup], optional): Reduce-scatter process group.
|
||||
Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
|
||||
reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
|
||||
fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
|
||||
tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' and 'auto'.
|
||||
If it's 'cpu', parameters, gradients and optimizer states will be offloaded to CPU, which means min CUDA memory will be used.
|
||||
If it's 'cuda', they won't be offloaded, which means max CUDA memory will be used.
|
||||
If it's 'auto', they are moving dynamically based on CPU and CUDA memory usage. It will utilize heterogeneous memory space evenly and well.
|
||||
Note that 'auto' policy can only work well when no other processes use CUDA during your training.
|
||||
Defaults to 'cuda'.
|
||||
gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
|
||||
reuse_fp16_shard (bool, optional): Whether to reuse fp16 shard for param and grad.
|
||||
Enabling this can reduce GPU memory usage, but you have to make sure you disable it when using gradient accumulation.
|
||||
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
|
||||
We find that PyTorch's optimizers don't support mixed precision,
|
||||
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
|
||||
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
module: nn.Module,
|
||||
shard_strategy: BaseShardStrategy,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_bucket_size_mb: int = 25,
|
||||
fp32_reduce_scatter: bool = False,
|
||||
tensor_placement_policy: str = 'cuda',
|
||||
gradient_predivide_factor: Optional[float] = 1.0,
|
||||
reuse_fp16_shard: bool = False,
|
||||
bf16: bool = False,
|
||||
*args,
|
||||
**kwargs):
|
||||
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
||||
super().__init__()
|
||||
self.logger = get_dist_logger()
|
||||
self.bf16 = bf16
|
||||
|
||||
# We force users to use ZeroInitContext
|
||||
for submodule in module.modules():
|
||||
sharded_cnt = 0
|
||||
unshard_cnt = 0
|
||||
for param in submodule.parameters(recurse=False):
|
||||
assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
|
||||
if param.colo_attr.param_is_sharded:
|
||||
sharded_cnt += 1
|
||||
else:
|
||||
unshard_cnt += 1
|
||||
assert (not sharded_cnt) or (not unshard_cnt), 'nn.Module can not both have shard param and unshard param'
|
||||
submodule.param_is_sharded = (sharded_cnt > 0)
|
||||
|
||||
self.sharded_params = []
|
||||
self.unshard_params = []
|
||||
for param in module.parameters():
|
||||
if param.colo_attr.param_is_sharded:
|
||||
self.sharded_params.append(param)
|
||||
else:
|
||||
self.unshard_params.append(param)
|
||||
|
||||
self.module = module
|
||||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.reduce_scatter_process_group = reduce_scatter_process_group or self.process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.rank = dist.get_rank(self.process_group)
|
||||
self.shard_strategy = shard_strategy
|
||||
|
||||
self._use_memory_tracer = tensor_placement_policy == 'auto'
|
||||
if self._use_memory_tracer:
|
||||
self._memstats_collector = MemStatsCollector()
|
||||
self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
|
||||
self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
|
||||
else:
|
||||
self._memstats_collector = None
|
||||
self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
|
||||
tensor_placement_policy)(mem_stats_collector=self._memstats_collector)
|
||||
|
||||
if 'warmup_non_model_data_ratio' in kwargs:
|
||||
if tensor_placement_policy != 'auto':
|
||||
self.logger.warning('setting warmup_non_model_data_ratio is useless if not use auto placement')
|
||||
else:
|
||||
ratio = kwargs['warmup_non_model_data_ratio']
|
||||
self._tensor_placement_policy._warmup_non_model_data_ratio = ratio
|
||||
self.logger.info(f'setting warmup_non_model_data_ratio as {ratio} for auto placement')
|
||||
|
||||
self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
|
||||
param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
|
||||
self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)
|
||||
|
||||
# Register hooks
|
||||
self._ophook_list = [
|
||||
ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
|
||||
]
|
||||
register_ophooks_recursively(self.module, self._ophook_list)
|
||||
self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
|
||||
self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
|
||||
|
||||
self.fp32_reduce_scatter = fp32_reduce_scatter
|
||||
self._cpu_offload: bool = tensor_placement_policy != 'cuda'
|
||||
for param in module.parameters():
|
||||
# Init `offload_grad`
|
||||
param.colo_attr.offload_grad = self._cpu_offload
|
||||
|
||||
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
|
||||
# So we use 1.0 as the default gradient_predivide_factor
|
||||
# However, if you set gradient_predivide_factor to None, we will set
|
||||
# gradient_predivide_factor to a value >= 1.0 automatically
|
||||
self.gradient_predivide_factor: float = gradient_predivide_factor if \
|
||||
gradient_predivide_factor is not None else \
|
||||
get_gradient_predivide_factor(self.world_size)
|
||||
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
|
||||
|
||||
self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
|
||||
self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
|
||||
self._require_backward_grad_sync: bool = True
|
||||
|
||||
self._cuda_margin_space = 0
|
||||
self.reuse_fp16_shard = reuse_fp16_shard
|
||||
|
||||
# record whether gradients have inf or nan
|
||||
self.overflow_counter = 0
|
||||
|
||||
def adjust_stateful_tensor_layout(self) -> None:
|
||||
self._stateful_tensor_mgr.adjust_layout()
|
||||
|
||||
@property
|
||||
def use_memory_tracer(self):
|
||||
return self._use_memory_tracer
|
||||
|
||||
@property
|
||||
def cuda_margin_space(self):
|
||||
return self._cuda_margin_space
|
||||
|
||||
@property
|
||||
def cpu_offload(self):
|
||||
return self._cpu_offload
|
||||
|
||||
def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
|
||||
"""
|
||||
dummy memory tracer collected information to a file.
|
||||
try:
|
||||
# forward: model(inputs)
|
||||
# backward: optimizer.backward()
|
||||
except Exception as e:
|
||||
model.dump_memory_stats()
|
||||
exit(0)
|
||||
"""
|
||||
if self._use_memory_tracer:
|
||||
self.logger.error(f'dump memory tracer collected information to a {filename}', ranks=[0])
|
||||
if gpc.get_global_rank() == 0:
|
||||
with open(filename, 'w+') as f:
|
||||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||
f.write('CUDA model data (GB)\n')
|
||||
f.write('\n')
|
||||
f.write('CUDA non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||
f.write('CPU non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cpu')))
|
||||
f.write('\n')
|
||||
|
||||
def _pre_forward_operations(self, *args):
|
||||
# the operation will affect the memory tracer behavior in ZeroHook
|
||||
if self._memstats_collector:
|
||||
self._start_collect_memstats()
|
||||
|
||||
for p in self.module.parameters():
|
||||
if hasattr(p, 'colo_attr'):
|
||||
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
|
||||
|
||||
self._stateful_tensor_mgr.start_iter()
|
||||
|
||||
def _post_forward_operations(self):
|
||||
for p in self.module.parameters():
|
||||
if hasattr(p, 'colo_attr'):
|
||||
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
|
||||
|
||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
self._pre_forward_operations(*args)
|
||||
cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
|
||||
args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
|
||||
outputs = self.module(*args, **kwargs)
|
||||
self._post_forward_operations()
|
||||
return outputs
|
||||
|
||||
def backward(self, loss):
|
||||
loss.backward()
|
||||
self._post_backward_operations()
|
||||
for ophook in self._ophook_list:
|
||||
ophook.post_iter()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
torch.autograd.backward(tensors=tensor, grad_tensors=grad)
|
||||
self._post_backward_operations()
|
||||
for ophook in self._ophook_list:
|
||||
ophook.post_iter()
|
||||
|
||||
def _update_memstats(self):
|
||||
if self._memstats_collector:
|
||||
self._finish_collect_memstats()
|
||||
# cuda margin space = cuda mem capacity - max fwd/bwd cuda mem used.
|
||||
# the way to calculate margin space is based on the assumption that
|
||||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_device_memory_capacity(
|
||||
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
"""
|
||||
The method includes operations required to be processed after backward
|
||||
1. update memory tracer.
|
||||
2. flush the gradient in buckets. Reducing partial gradients in each process.
|
||||
3. shard tensors not dealed in the zero hook
|
||||
4. move sharded param grad payload to param.grad
|
||||
"""
|
||||
# 1. update memory tracer.
|
||||
self._update_memstats()
|
||||
|
||||
# 2. flush the gradient in buckets. Reducing partial gradients in each process.
|
||||
if self._require_backward_grad_sync:
|
||||
# Flush any unreduced buckets in the post_backward stream.
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
self.reducer.flush()
|
||||
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||
self.reducer.free()
|
||||
|
||||
# 3. shard tensors not dealed in the zero hook
|
||||
tensor_list = []
|
||||
for p in self.sharded_params:
|
||||
if not p.colo_attr.param_is_sharded:
|
||||
tensor_list.append(p.colo_attr.sharded_data_tensor)
|
||||
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
p.colo_attr.set_data_none()
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
|
||||
# 4. set all parameters' grad to None
|
||||
for p in self.module.parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
# Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
|
||||
# NOTE() (no-sync)/sync pass: (not conduct)/conduct gradient all reducing between process group.
|
||||
# If _require_backward_grad_sync is True,
|
||||
# p.grad remains the accumulated unsharded gradient from prior no-sync passes.
|
||||
# We also allows to interleave no-sync pass with sync passes, if desired.
|
||||
if not self._require_backward_grad_sync:
|
||||
continue
|
||||
|
||||
p.grad = None
|
||||
|
||||
@torch.no_grad()
|
||||
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
|
||||
full gradient for the local batch. The reduce-scatter op will save
|
||||
a single shard of the summed gradient across all
|
||||
GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::
|
||||
|
||||
before reduce_scatter:
|
||||
param.grad (GPU #0): [1, 2, 3, 4]
|
||||
param.grad (GPU #1): [5, 6, 7, 8]
|
||||
|
||||
after reduce_scatter:
|
||||
param.grad (GPU #0): [6, 8] # 1+5, 2+6
|
||||
param.grad (GPU #1): [10, 12] # 3+7, 4+8
|
||||
|
||||
The local GPU's ``optim.step`` is responsible for updating a single
|
||||
shard of params, also corresponding to the current GPU's rank. This
|
||||
alignment is created by `param.colo_attr.grad`, which ensures that
|
||||
the local optimizer only sees the relevant parameter shard.
|
||||
"""
|
||||
if grad is None:
|
||||
return
|
||||
assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
|
||||
if not self._require_backward_grad_sync:
|
||||
return
|
||||
# used to cheat Pytorch, since we can't return None
|
||||
empty_grad = torch.empty_like(grad)
|
||||
free_storage(empty_grad)
|
||||
# As torch didn't allow modifying grad in hook, we make a copy
|
||||
grad = grad.clone()
|
||||
if param.colo_attr.is_replicated:
|
||||
self._reduce_scatter_handler(param, grad)
|
||||
else:
|
||||
self._save_grad(param, grad)
|
||||
return empty_grad
|
||||
|
||||
def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
|
||||
self.comm_stream.wait_stream(torch.cuda.current_stream())
|
||||
with torch.cuda.stream(self.comm_stream):
|
||||
if self.fp32_reduce_scatter:
|
||||
grad.data = grad.data.to(param.dtype)
|
||||
if self.gradient_predivide_factor > 1.0:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
grad.data.div_(self.gradient_predivide_factor)
|
||||
if self.world_size > 1:
|
||||
grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
|
||||
self.reducer.reduce_scatter_async(grad_chunks,
|
||||
group=self.reduce_scatter_process_group,
|
||||
callback_fn=functools.partial(self._reduce_scatter_callback, param))
|
||||
else:
|
||||
self._reduce_scatter_callback(param, grad)
|
||||
torch.cuda.current_stream().wait_stream(self.comm_stream)
|
||||
|
||||
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
|
||||
assert isinstance(reduced_grad,
|
||||
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
|
||||
reduced_grad.data = reduced_grad.data.contiguous().view(-1)
|
||||
if self.gradient_postdivide_factor > 1:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
reduced_grad.data.div_(self.gradient_postdivide_factor)
|
||||
self._save_grad(param, reduced_grad)
|
||||
|
||||
# FIXME(ver217): refactor the below line when impl eviction policy
|
||||
def _save_grad(self, param: Parameter, grad: torch.Tensor):
|
||||
|
||||
# record whether we have overflow
|
||||
self.overflow_counter += torch.isinf(grad).any().item()
|
||||
self.overflow_counter += torch.isnan(grad).any().item()
|
||||
|
||||
# move gradient to cpu
|
||||
if param.colo_attr.offload_grad:
|
||||
colo_model_data_move_to_cpu(grad)
|
||||
|
||||
if self.reuse_fp16_shard:
|
||||
# make parameters point to gradient
|
||||
|
||||
assert param.colo_attr.saved_grad.is_null(
|
||||
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
|
||||
|
||||
param.colo_attr.grad_payload_reset(grad.data)
|
||||
# release the memory of param
|
||||
# we set a false None for parameter's payload
|
||||
# so we can get parameter's device and dtype later in optimizer
|
||||
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
|
||||
|
||||
if param.colo_attr.is_replicated:
|
||||
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||
else:
|
||||
|
||||
fp32_grad = cast_tensor_to_fp32(grad)
|
||||
|
||||
if param.colo_attr.saved_grad.is_null():
|
||||
param.colo_attr.grad_payload_reset(fp32_grad)
|
||||
else:
|
||||
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
|
||||
|
||||
# keep saved_grad in HOLD state
|
||||
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
return self.module.parameters(recurse=recurse)
|
||||
|
||||
def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, Parameter]]:
|
||||
return self.module.named_parameters(prefix, recurse)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
return self._colo_state_dict(destination,
|
||||
prefix,
|
||||
keep_vars,
|
||||
shard_strategy=self.shard_strategy,
|
||||
state_dict_func=nn.Module.state_dict,
|
||||
module_to_load=self.module,
|
||||
sharded_params=self.sharded_params,
|
||||
process_group=self.process_group)
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True) -> None:
|
||||
for name, p in self.named_parameters():
|
||||
if name in state_dict:
|
||||
p.colo_attr.data_payload_reset(state_dict[name].to(dtype=p.colo_attr.data_payload.dtype,
|
||||
device=p.colo_attr.data_payload.device))
|
||||
# Force re-shard
|
||||
p.colo_attr.sharded_data_tensor.is_sharded = False
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor])
|
||||
elif strict:
|
||||
raise RuntimeError(f'Missing key in state_dict: {name}')
|
||||
|
||||
def _colo_state_dict(self,
|
||||
destination=None,
|
||||
prefix='',
|
||||
keep_vars=False,
|
||||
shard_strategy: Optional[BaseShardStrategy] = None,
|
||||
state_dict_func=None,
|
||||
module_to_load=None,
|
||||
sharded_params=[],
|
||||
process_group=None) -> 'OrderedDict[str, torch.Tensor]':
|
||||
if len(sharded_params) == 0:
|
||||
for param in self.parameters():
|
||||
if param.colo_attr.param_is_sharded:
|
||||
sharded_params.append(param)
|
||||
if shard_strategy is not None:
|
||||
shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
|
||||
for p in sharded_params:
|
||||
p.data = p.colo_attr.data_payload
|
||||
module_to_load = module_to_load or self
|
||||
gathered_state_dict = state_dict_func(module_to_load, destination, prefix, keep_vars)
|
||||
gathered_state_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in gathered_state_dict.items()}
|
||||
if shard_strategy is not None:
|
||||
shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in sharded_params], process_group)
|
||||
for p in sharded_params:
|
||||
p.colo_attr.set_data_none()
|
||||
return gathered_state_dict
|
||||
|
||||
def _colo_load_from_state_dict(self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
shard_strategy=None):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
shard_strategy (Optional[BaseShardStrategy], optional): A shard strategy to manage shard behavior. Defaults to None.
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if hasattr(param, 'colo_attr'):
|
||||
param.colo_attr.data_payload_reset(
|
||||
input_param.to(dtype=param.colo_attr.data_payload.dtype,
|
||||
device=param.colo_attr.data_payload.device))
|
||||
if shard_strategy is not None:
|
||||
# Force re-shard
|
||||
param.colo_attr.sharded_data_tensor.is_sharded = False
|
||||
shard_strategy.shard([param.colo_attr.sharded_data_tensor])
|
||||
else:
|
||||
# This is used to avoid copying uninitialized parameters into
|
||||
# non-lazy modules, since they dont have the hook to do the checks
|
||||
# in such case, it will error when accessing the .shape attribute.
|
||||
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
|
||||
if not is_param_lazy and input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(
|
||||
key, input_param.shape, param.shape))
|
||||
continue
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
|
||||
ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state", nn.Module.set_extra_state) is not nn.Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
assert isinstance(self.module, nn.ModuleList)
|
||||
return self.module[idx]
|
||||
|
||||
def __len__(self):
|
||||
assert isinstance(self.module, nn.ModuleList)
|
||||
return len(self.module)
|
||||
|
||||
def __iter__(self):
|
||||
assert isinstance(self.module, nn.ModuleList)
|
||||
return iter(self.module)
|
@@ -1,20 +0,0 @@
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
|
||||
|
||||
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
|
||||
"""
|
||||
copy param of the ShardedModelV2 to other_model.
|
||||
Note the other_model has to be the same as self.
|
||||
"""
|
||||
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
|
||||
assert hasattr(zero_param, 'colo_attr')
|
||||
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
|
||||
param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
|
@@ -1,118 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.legacy.registry import OPHOOKS
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class ZeroHook(BaseOpHook):
|
||||
"""
|
||||
A hook to process sharded param for ZeRO method.
|
||||
Warning: this class has been deprecated after version 0.1.12
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
shard_strategy: BaseShardStrategy,
|
||||
memstarts_collector: Optional[MemStatsCollector] = None,
|
||||
stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
|
||||
process_group: Optional[dist.ProcessGroup] = None):
|
||||
super().__init__()
|
||||
self.logger = get_dist_logger("ZeROHook")
|
||||
self.shard_strategy = shard_strategy
|
||||
self.process_group = process_group
|
||||
|
||||
# NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
|
||||
self.computing_device = get_current_device()
|
||||
|
||||
self._memstarts_collector = memstarts_collector
|
||||
self._stateful_tensor_mgr = stateful_tensor_mgr
|
||||
|
||||
def gather_parameters(self, module: torch.nn.Module):
|
||||
# gather sharded parameters
|
||||
if module.param_is_sharded:
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.gather(tensor_list, self.process_group)
|
||||
|
||||
def shard_parameters(self, module: torch.nn.Module):
|
||||
# shard gathered parameters
|
||||
if module.param_is_sharded:
|
||||
tensor_list = []
|
||||
for param in module.parameters(recurse=False):
|
||||
assert hasattr(param, 'colo_attr')
|
||||
tensor_list.append(param.colo_attr.sharded_data_tensor)
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
|
||||
def adjust_module_data(self, module: torch.nn.Module):
|
||||
# record overall data statistics
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.sample_overall_data()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
|
||||
|
||||
# adjust stateful tensor to get enough CUDA memory
|
||||
self._stateful_tensor_mgr.adjust_layout()
|
||||
|
||||
# record model data statistics
|
||||
if self._memstarts_collector:
|
||||
self._memstarts_collector.record_model_data_volume()
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
self.adjust_module_data(module)
|
||||
self.gather_parameters(module)
|
||||
for param in module.parameters(recurse=False):
|
||||
param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
|
||||
# change tensor state to HOLD_AFTER_FWD
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
|
||||
|
||||
self.shard_parameters(module)
|
||||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
self.adjust_module_data(module)
|
||||
self.gather_parameters(module)
|
||||
for param in module.parameters(recurse=False):
|
||||
param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
|
||||
# change tensor state to HOLD_AFTER_BWD
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
|
||||
self.shard_parameters(module)
|
||||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
if self._stateful_tensor_mgr:
|
||||
self.logger.debug(
|
||||
f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
|
||||
ranks=[0])
|
||||
self._stateful_tensor_mgr.finish_iter()
|
@@ -1,3 +0,0 @@
|
||||
from .sharded_optim_v2 import ShardedOptimizerV2
|
||||
|
||||
__all__ = ['ShardedOptimizerV2']
|
@@ -1,399 +0,0 @@
|
||||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
from enum import Enum
|
||||
from os import stat
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.zero.legacy.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 1
|
||||
UNSCALED = 2
|
||||
|
||||
|
||||
class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||
"""A wrapper for optimizer. ``ShardedOptimizerV2`` and ``ShardedModelV2`` implement Zero Redundancy Optimizer (ZeRO).
|
||||
|
||||
By default the ZeRO optimizer stage 3 offload Optimizer States on CPU.
|
||||
|
||||
We apply the Device-aware Operator Placement technique for OS placement from the following paper.
|
||||
|
||||
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
|
||||
|
||||
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
|
||||
which is detected by a runtime memory tracer.
|
||||
|
||||
We place as many OS chunks in the margin space as possible.
|
||||
|
||||
The size of margin space can be controlled by ``gpu_margin_mem_ratio``.
|
||||
If it is set as ``0.0``, it is the same as classical ZeRO optimizer.
|
||||
|
||||
Note:
|
||||
You must use ``ShardedOptimizerV2`` with ``ShardedModelV2``.
|
||||
|
||||
Note:
|
||||
Make sure you set ``tensor_placement_policy`` in ``ShardedModelV2`` to `"auto"`,
|
||||
if you set ``gpu_margin_mem_ratio > 0``.
|
||||
|
||||
Args:
|
||||
sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
|
||||
shard strategy provided by sharded model to shard param fp32 tensors.
|
||||
optimizer (Optimizer): An Optimizer instance.
|
||||
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
||||
which will be used when using hybrid CPU optimizer.
|
||||
This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
|
||||
Defaults to 0.0.
|
||||
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
||||
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
|
||||
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
|
||||
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
|
||||
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
|
||||
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
|
||||
mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.
|
||||
|
||||
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
sharded_model: ShardedModelV2,
|
||||
optimizer: Optimizer,
|
||||
gpu_margin_mem_ratio: float = 0.0,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32,
|
||||
dp_process_group: Optional[ProcessGroup] = None,
|
||||
mp_process_group: Optional[ProcessGroup] = None,
|
||||
verbose: bool = False) -> None:
|
||||
assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
|
||||
assert not isinstance(optimizer, ShardedOptimizerV2), 'Nested ShardedOptimizerV2 is not supported.'
|
||||
|
||||
super().__init__(optimizer)
|
||||
self.shard_strategy = sharded_model.shard_strategy
|
||||
self.model: ShardedModelV2 = sharded_model
|
||||
self.bf16 = sharded_model.bf16
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
|
||||
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
|
||||
# and it must set `num_fp32_shards_per_param` correctly
|
||||
self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
|
||||
optimizer, 'num_fp32_shards_per_param', 0) >= 2
|
||||
self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
|
||||
self.optim_state: OptimState = OptimState.UNSCALED
|
||||
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
|
||||
self._logger = get_dist_logger("ShardedOptimizerV2")
|
||||
self._verbose = verbose
|
||||
self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
|
||||
|
||||
# Store fp32 param shards
|
||||
self._register_master_weight()
|
||||
if self.gpu_margin_mem_ratio != 0.0 and not isinstance(sharded_model._tensor_placement_policy,
|
||||
AutoTensorPlacementPolicy):
|
||||
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"',
|
||||
ranks=[0])
|
||||
|
||||
if self._verbose:
|
||||
self._logger.debug(
|
||||
f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])
|
||||
|
||||
self._use_memory_tracer = self.model.use_memory_tracer
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale.item()
|
||||
|
||||
def get_memory_usage(self) -> Tuple[int, int]:
|
||||
""" Get the memory usage of the optimizer. Including master_params (param fp32),
|
||||
momentum (``self.state[p]['exp_avg']``) variance (``self.state[p]['exp_avg_sq']``)
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: cuda/cpu memory usage in Byte.
|
||||
"""
|
||||
cuda_use = 0
|
||||
cpu_use = 0
|
||||
|
||||
def update_mem_use(t):
|
||||
nonlocal cuda_use
|
||||
nonlocal cpu_use
|
||||
t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
|
||||
cuda_use += t_cuda_use
|
||||
cpu_use += t_cpu_use
|
||||
|
||||
for _, p_fp32 in self.master_params.items():
|
||||
update_mem_use(p_fp32)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.optim.state[p]
|
||||
for k, v in state.items():
|
||||
update_mem_use(v)
|
||||
|
||||
return cuda_use, cpu_use
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
self._zero_grad()
|
||||
|
||||
def backward(self, loss: Tensor) -> None:
|
||||
if not self.bf16:
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
self._grad_prepared = False
|
||||
self.model.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
|
||||
# This function is called except the last stage of pipeline parallel
|
||||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
if not self.bf16:
|
||||
self.optim_state = OptimState.SCALED
|
||||
self._grad_prepared = False
|
||||
self.model.backward_by_grad(tensor, grad)
|
||||
|
||||
def clip_grad_norm(self, model: nn.Module, max_norm: float):
|
||||
self._prepare_grads()
|
||||
if not self.bf16 and self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
return super().clip_grad_norm(model, max_norm)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
|
||||
self._prepare_grads()
|
||||
# unscale grads if scaled
|
||||
if not self.bf16 and self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
|
||||
self._maybe_move_fp32_shards()
|
||||
if not self.bf16:
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
if found_inf:
|
||||
self._logger.warning('found inf during ShardedOptimV2 step')
|
||||
self._zero_grad(recover_data=True)
|
||||
return
|
||||
|
||||
self._point_param_fp16_to_master_param()
|
||||
|
||||
if self._verbose:
|
||||
gpu_mem, cpu_mem = self.get_memory_usage()
|
||||
self._logger.debug(
|
||||
f"Before step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
|
||||
ranks=[0])
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
|
||||
if self._verbose:
|
||||
gpu_mem, cpu_mem = self.get_memory_usage()
|
||||
self._logger.debug(
|
||||
f"After step ShardedOptimizerV2 consumes {gpu_mem / 1e6} MB CUDA Memory, {cpu_mem / 1e6} MB CUDA Memory!",
|
||||
ranks=[0])
|
||||
|
||||
self._copy_master_model_to_model_fp16()
|
||||
return ret
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(self.model.overflow_counter)
|
||||
|
||||
# all-reduce across dp group
|
||||
dist.all_reduce(self._found_overflow, group=self.dp_process_group)
|
||||
|
||||
# all-reduce over model parallel group
|
||||
dist.all_reduce(self._found_overflow, group=self.mp_process_group)
|
||||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _unscale_grads(self):
|
||||
assert self.optim_state == OptimState.SCALED
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is not None:
|
||||
p.grad.data.div_(self.loss_scale)
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
|
||||
def _zero_grad(self, recover_data: bool = False):
|
||||
"""zero grad and maybe recover fp16 params
|
||||
When `reuse_fp16_shard` is enabled,
|
||||
p.colo_attr.sharded_data_tensor stores grad here.
|
||||
We have to recover them from fp32 params.
|
||||
|
||||
Args:
|
||||
recover_data (bool, optional): Whether to recover fp16 param from fp32 param. Defaults to False.
|
||||
"""
|
||||
# We must set grad to None
|
||||
# Because grad here is sharded
|
||||
# But next backward pass will create a full grad first
|
||||
# Which leads to wrong accumulation
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
# p.colo_attr.sharded_data_tensor stores grad now
|
||||
# we have to recover fp16 param
|
||||
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
|
||||
if recover_data and reuse_fp16_shard:
|
||||
self._copy_master_param_to_param_fp16(p)
|
||||
else:
|
||||
# release saved gradient
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
self.model.overflow_counter = 0 # set overflow counter to zero
|
||||
|
||||
def sync_grad(self):
|
||||
pass
|
||||
|
||||
def _register_master_weight(self):
|
||||
self.master_params: Dict[Parameter, StatefulTensor] = {}
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
|
||||
shard_flag = not p.colo_attr.sharded_data_tensor.is_sharded and p.colo_attr.is_replicated
|
||||
if shard_flag:
|
||||
# we always shard replicated parameters
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
|
||||
if shard_flag:
|
||||
# In this branch, there's no need to shard param
|
||||
# So we gather here
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
|
||||
def _maybe_move_fp32_shards(self):
|
||||
if self._should_move_fp32_shards_h2d:
|
||||
self._should_move_fp32_shards_h2d = False
|
||||
available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
|
||||
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
||||
fp32_shards_used_cuda_margin_mem = 0
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.colo_attr.saved_grad.is_null():
|
||||
continue
|
||||
shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
|
||||
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
|
||||
colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
|
||||
colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
|
||||
p.colo_attr.offload_grad = False
|
||||
fp32_shards_used_cuda_margin_mem += shard_mem
|
||||
state = self.optim.state[p]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, Tensor):
|
||||
state[k] = v.cuda()
|
||||
|
||||
def _prepare_grads(self):
|
||||
if self._grad_prepared:
|
||||
return
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.colo_attr.saved_grad.is_null():
|
||||
continue
|
||||
p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
|
||||
# If reuse_fp16_shard, grad fp16 which wasn't be offloaded may be evicted to CPU
|
||||
if not p.colo_attr.offload_grad:
|
||||
colo_model_data_tensor_move_inline(p.colo_attr.saved_grad, torch.cuda.current_device())
|
||||
# FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
|
||||
# If we change p.grad directly
|
||||
# it may raise error because of different shape/dtype/device of p.data and p.grad
|
||||
# We just set p.data = p.colo_attr.saved_grad.payload here
|
||||
p.data = p.colo_attr.grad_payload
|
||||
p.grad = p.colo_attr.grad_payload
|
||||
# Set p.data to empty tensor, in case of memory leaking
|
||||
p.colo_attr.set_data_none()
|
||||
self._grad_prepared = True
|
||||
|
||||
def _point_param_fp16_to_master_param(self):
|
||||
# assign master param pointers to p.data.
|
||||
# We will not trigger data copy here.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
self.master_params[p].trans_state(TensorState.COMPUTE)
|
||||
p.data = self.master_params[p].payload
|
||||
# Now p.data is sharded
|
||||
# So optimizer states are sharded naturally
|
||||
|
||||
def _copy_master_model_to_model_fp16(self):
|
||||
# Copy master param data (fp32) to payload of colo_attr (fp16)
|
||||
# TODO() improve efficiency by gathering tensors into a chunk and transferring
|
||||
# a chunk.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
self._copy_master_param_to_param_fp16(p)
|
||||
|
||||
def _copy_master_param_to_param_fp16(self, p):
|
||||
# flush gradient
|
||||
if p.colo_attr.sharded_data_tensor.payload_size == 0:
|
||||
# here reuse_fp16_shard is True
|
||||
# in order to use copy below, we should give sharded data tensor a payload
|
||||
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
|
||||
else:
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
|
||||
p.data = self.master_params[p].payload
|
||||
|
||||
# we need to allocate new memory for keep_not_shard parameters
|
||||
# in order to use copy, otherwise, the sizes of tensor is not compatible
|
||||
if p.colo_attr.data_payload.numel() != p.data.numel():
|
||||
p.colo_attr.data_payload_reset(
|
||||
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
|
||||
p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
|
||||
p.colo_attr.set_data_none()
|
||||
|
||||
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
||||
# We gather full fp16 param here
|
||||
p.colo_attr.sharded_data_tensor.is_sharded = True # since only gradient is sharded, we should set to True
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
|
||||
self.master_params[p].trans_state(TensorState.HOLD)
|
||||
|
||||
def state_dict(self):
|
||||
optim_state_dict = super().state_dict()
|
||||
scaler_state_dict = self.grad_scaler.state_dict()
|
||||
optim_state_dict['scaler'] = scaler_state_dict
|
||||
return optim_state_dict
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
if 'scaler' not in args[0]:
|
||||
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
|
||||
else:
|
||||
scaler_state_dict = args[0].pop('scaler')
|
||||
self.grad_scaler.load_state_dict(scaler_state_dict)
|
||||
super().load_state_dict(*args, **kwargs)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.optim.state[p]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, Tensor):
|
||||
state[k] = v.to(dtype=self.master_params[p].dtype, device=self.master_params[p].device)
|
@@ -1,4 +0,0 @@
|
||||
from .sharded_param import ShardedParamV2
|
||||
from .sharded_tensor import ShardedTensor
|
||||
|
||||
__all__ = ['ShardedTensor', 'ShardedParamV2']
|
@@ -1,110 +0,0 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
|
||||
|
||||
from .sharded_tensor import ShardedTensor
|
||||
|
||||
EMPTY_TENSOR_DICT = {}
|
||||
|
||||
|
||||
def get_empty_tensor(device: torch.device, dtype: torch.dtype):
|
||||
key = (device, dtype)
|
||||
if key not in EMPTY_TENSOR_DICT:
|
||||
EMPTY_TENSOR_DICT[key] = torch.empty(0, dtype=dtype, device=device)
|
||||
|
||||
return EMPTY_TENSOR_DICT[key]
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
|
||||
def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
|
||||
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
|
||||
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
# This attribute must be initialized in ShardedModel
|
||||
self.offload_grad: bool = False
|
||||
|
||||
# make sure the shared param is the only owner of payload
|
||||
# The param.data maybe used to init the other part of the model.
|
||||
# For example: File "resnet.py", line 190, in __init__
|
||||
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
# So we can not empty the .data at this time
|
||||
self.param = param
|
||||
if set_data_none:
|
||||
self.set_data_none()
|
||||
|
||||
def get_payload_tensors(self) -> List[StatefulTensor]:
|
||||
"""returns stateful tensors kept by this class.
|
||||
"""
|
||||
return [self._sharded_data_tensor]
|
||||
|
||||
def set_data_none(self):
|
||||
self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
|
||||
|
||||
def set_grad_none(self):
|
||||
self.saved_grad.set_null()
|
||||
|
||||
@property
|
||||
def sharded_data_tensor(self):
|
||||
return self._sharded_data_tensor
|
||||
|
||||
@property
|
||||
def data_payload(self):
|
||||
assert not self.sharded_data_tensor.is_null()
|
||||
return self.sharded_data_tensor.payload
|
||||
|
||||
@property
|
||||
def grad_payload(self):
|
||||
assert not self.saved_grad.is_null()
|
||||
return self.saved_grad.payload
|
||||
|
||||
@property
|
||||
def param_is_sharded(self):
|
||||
return self.sharded_data_tensor.is_sharded
|
||||
|
||||
def data_payload_reset(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.sharded_data_tensor.payload_reset(tensor)
|
||||
|
||||
def grad_payload_reset(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.saved_grad.payload_reset(tensor)
|
||||
|
||||
def get_memory_usage(self) -> Tuple[int, int]:
|
||||
"""
|
||||
get the memory usage of the param, including data and grad
|
||||
Returns:
|
||||
Tuple[int, int]: cuda mem usage in Byte, cpu memory usage in Byte
|
||||
"""
|
||||
cuda_mem_use, cpu_mem_use = 0, 0
|
||||
|
||||
def _update_mem_use(t: Optional[torch.Tensor]):
|
||||
if t is None:
|
||||
return
|
||||
assert isinstance(t, torch.Tensor)
|
||||
nonlocal cuda_mem_use
|
||||
nonlocal cpu_mem_use
|
||||
t_cuda, t_cpu = colo_tensor_mem_usage(t)
|
||||
cuda_mem_use += t_cuda
|
||||
cpu_mem_use += t_cpu
|
||||
|
||||
address_set = set()
|
||||
_update_mem_use(self.data_payload)
|
||||
address_set.add(self.data_payload.data_ptr())
|
||||
|
||||
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.grad_payload)
|
||||
address_set.add(self.saved_grad.data_ptr())
|
||||
|
||||
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.data)
|
||||
address_set.add(self.param.data.data_ptr())
|
||||
|
||||
if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.param.grad)
|
||||
|
||||
return cuda_mem_use, cpu_mem_use
|
@@ -1,40 +0,0 @@
|
||||
import torch
|
||||
|
||||
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
|
||||
class ShardedTensor(StatefulTensor):
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
|
||||
r"""
|
||||
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
|
||||
"""
|
||||
assert tensor.requires_grad is False
|
||||
super().__init__(tensor, state)
|
||||
|
||||
# kept the shape, numel and dtype of the init tensor.
|
||||
self._origin_shape = tensor.shape
|
||||
self._origin_numel = tensor.numel()
|
||||
self._origin_dtype = tensor.dtype
|
||||
self._is_sharded = False
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
assert self._payload.dtype == self._origin_dtype
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def origin_numel(self) -> int:
|
||||
return self._origin_numel
|
||||
|
||||
@property
|
||||
def origin_shape(self) -> int:
|
||||
return self._origin_shape
|
||||
|
||||
@property
|
||||
def is_sharded(self):
|
||||
return self._is_sharded
|
||||
|
||||
@is_sharded.setter
|
||||
def is_sharded(self, flag: bool):
|
||||
self._is_sharded = flag
|
@@ -7,9 +7,6 @@ from torch import Tensor, inf
|
||||
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.utils import is_model_parallel_parameter
|
||||
|
||||
|
||||
def flatten(input_):
|
||||
return _flatten_dense_tensors(input_)
|
||||
|
Reference in New Issue
Block a user