[zero] reorganize zero/gemini folder structure (#3424)

* [zero] refactor low-level zero folder structure

* [zero] fix legacy zero import path

* [zero] fix legacy zero import path

* [zero] remove useless import

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor gemini folder structure

* [zero] refactor legacy zero import path

* [zero] fix test import path

* [zero] fix test

* [zero] fix circular import

* [zero] update import
ver217 2023-04-04 13:48:16 +08:00 committed by GitHub
parent b09adff724
commit 26b7aac0be
142 changed files with 1435 additions and 1404 deletions
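
In practice the reorganization amounts to three import-path moves, visible throughout the hunks below: Gemini internals move from `colossalai.gemini.*` (and `colossalai.nn.parallel`) to `colossalai.zero.gemini.*`, the old sharded ZeRO implementation moves under `colossalai.zero.legacy.*`, and the user-facing wrappers are re-exported from `colossalai.zero`. A minimal before/after sketch of the migration, using only paths that appear in the diffs below:

# Before this commit
# from colossalai.nn.parallel import ZeroDDP, GeminiDDP, zero_model_wrapper, zero_optim_wrapper
# from colossalai.gemini.memory_tracer import MemStats
# from colossalai.zero.init_ctx import no_shard_zero_decrator

# After this commit
from colossalai.zero import ZeroDDP, GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.memory_tracer import MemStats
from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator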

View File

@@ -14,17 +14,16 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 import colossalai
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
-from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
-from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.tensor import ProcessGroup, ShardSpec
 from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini.utils import get_static_torch_model
-logger = get_dist_logger(__name__)
 from .base import Strategy
 from .ddp import DDPStrategy
+logger = get_dist_logger(__name__)
 class ColossalAIStrategy(DDPStrategy):
     """

View File

@@ -4,8 +4,8 @@ from typing import Optional, Set
 import torch
 import torch.nn as nn
-from colossalai.gemini.tensor_utils import free_storage
 from colossalai.nn.parallel.data_parallel import _cast_float
+from colossalai.zero.legacy.gemini.tensor_utils import free_storage
 from .region_manager import RegionManager
 from .util import GlobalRuntimeInfo

View File

@@ -1,7 +1,10 @@
-from typing import List, Dict, Tuple
+from typing import Dict, List, Tuple
 import torch
 from torch.fx import Node
-from colossalai.gemini.tensor_utils import alloc_storage, free_storage
+from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
 class Region:
     """
@@ -52,15 +55,13 @@ class Region:
         Map the parameters in the region to a contiguous memory space.
         """
-        self.fp16_data = torch.zeros(
-            self.param_num, dtype=torch.half, device='cuda')
+        self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda')
         offset = 0
         for param in self.fp16_params:
             param.data = param.data.cuda()
             p_num = param.data.numel()
             self.fp16_data[offset:offset + p_num].copy_(param.data.flatten())
-            param.data = self.fp16_data[offset:offset +
-                                        p_num].view(param.data.shape)
+            param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape)
             self.param_to_range[param] = (offset, offset + p_num)
             offset += p_num
@@ -141,4 +142,4 @@ class Region:
     def __update_params_ptr(self) -> None:
         for param in self.fp16_params:
             begin, end = self.param_to_range[param]
-            param.data = self.fp16_data[begin:end].view(param.data.shape)
+            param.data = self.fp16_data[begin:end].view(param.data.shape)

View File

@@ -14,12 +14,12 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
 from colossalai.cluster import DistCoordinator
-from colossalai.gemini.memory_tracer import MemStats
 from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.utils import get_current_device
-from colossalai.utils.model.colo_init_context import _convert_to_coloparam
+from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam
+from colossalai.zero.gemini.memory_tracer import MemStats
 from .plugin_base import Plugin

View File

@@ -10,8 +10,8 @@ from torch.nn.modules.loss import _Loss
 from colossalai.engine.gradient_handler import BaseGradientHandler
 from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule
-from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively
 from colossalai.logging import get_dist_logger
+from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively
 class Engine:

View File

@@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule):
         return self._move_to_device(mciro_batch_data)
     def pre_processing(self, engine):
-        from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
+        from colossalai.zero.legacy import ShardedModelV2
         # TODO: remove this after testing new zero with pipeline parallelism
         model = engine.model

View File

@@ -1,9 +0,0 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .gemini_mgr import GeminiManager
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import TensorPlacementPolicyFactory
__all__ = [
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager',
'search_chunk_configuration'
]

View File

@@ -29,13 +29,12 @@ from colossalai.engine.schedule import (
     PipelineSchedule,
     get_tensor_shape,
 )
-from colossalai.gemini.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param
 from colossalai.utils.moe import sync_moe_model_param
-from colossalai.zero import convert_to_zero_v2
-from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
+from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2
+from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
 def get_default_parser():

View File

@@ -9,7 +9,7 @@ import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
-from colossalai.zero.init_ctx import no_shard_zero_decrator
+from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator
 class MoeExperts(nn.Module):

View File

@@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts
 from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router
 from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator
 from colossalai.utils import get_current_device
-from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator
+from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator
 @no_shard_zero_decrator(is_replicated=True)

View File

@@ -1,15 +0,0 @@
from typing import Any
import torch
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
__all__ = ['GeminiAdamOptimizer']
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)

View File

@@ -1,5 +1,5 @@
-from .data_parallel import ColoDDP, ZeroDDP
-from .gemini_parallel import GeminiDDP
-from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper
-__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper']
+from .data_parallel import ColoDDP
+__all__ = [
+    'ColoDDP',
+]

View File

@@ -1,31 +1,14 @@
-import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Iterable, Optional, Set
 import torch
 import torch.distributed as dist
-import torch.nn as nn
-from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState
-from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.gemini.memory_tracer import OrderedParamGenerator
-from colossalai.logging import get_dist_logger
-from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
-from colossalai.tensor import ReplicaSpec
-from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ColoParamOpHookManager
-from colossalai.utils import get_current_device, is_ddp_ignored
-from colossalai.zero.utils.gemini_hook import GeminiZeROHook
+from colossalai.utils import is_ddp_ignored
 from .reducer import Reducer
-from .utils import get_static_torch_model
-try:
-    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
-except ImportError:
-    _EXTRA_STATE_KEY_SUFFIX = '_extra_state'
 def free_storage(data: torch.Tensor) -> None:
@@ -189,507 +172,3 @@ class ColoDDP(torch.nn.Module):
    def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
        return self.module.load_state_dict(state_dict, strict)
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
else:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
p.grad = None
def _pre_backward(self):
# set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced
for param in self.param2name:
if not is_ddp_ignored(param):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Args:
param_list (_type_): a list of torch.nn.Parameters
only_rank_0 (_type_): _description_
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
assert tensor not in param_to_save_data
param_to_save_data[tensor] = record_tensor
del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
prefix = ''
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
return
try:
with torch.no_grad():
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
elif strict:
missing_keys.append(state_key)
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
fp32_to_name = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
for name, buf in persistent_buffers.items():
if buf is not None:
load(name, buf, buf.copy_)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state",
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
if input_name not in local_state:
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks are in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()

View File

@@ -1,63 +0,0 @@
from typing import Optional
import torch
from colossalai.gemini.chunk import init_chunk_manager
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.gemini.memory_tracer import MemStats
from .data_parallel import ZeroDDP
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.Module warpper using ZeRO-DP and Genimi.
ZeRO is for parallel. Gemini is for memory management.
WARNING: The class will modify the module inline!
Example:
model is initialized under the context of ColoInitContext
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of DNN.
Users can provide this argument to speed up searching.
If users do not know this argument before training, it is ok. We will use a default value 1024.
min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte.
If the aggregate size of parameters is still samller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer.
"""
# some ugly hotfix for the compatibility with Lightning
if search_range_mb is None:
search_range_mb = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb,
strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)

View File

@@ -1,41 +1,16 @@
-from typing import Tuple
-import torch
-import torch.nn as nn
-from colossalai.logging import get_dist_logger
-from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
-from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2
-from ..nn.optimizer.zero_optimizer import ZeroOptimizer
-def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
-                       optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
-    """
-    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
-    :param model: Your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer_config: Your optimizer object
-    :type optimizer_config: :class:`dict`
-    :return: (model, optimizer)
-    :rtype: Tuple
-    """
-    logger = get_dist_logger('convert_to_zero_v2')
-    logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
-    if optimizer_config is None:
-        optimizer_config = dict()
-    logger.info(f'model_config is {model_config}', ranks=[0])
-    if model_config is None:
-        model_config = dict()
-    zero_model = ShardedModelV2(model, **model_config)
-    zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
-    return zero_model, zero_optimizer
-__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer']
+from .gemini import (
+    ColoInitContext,
+    GeminiAdamOptimizer,
+    GeminiDDP,
+    ZeroDDP,
+    ZeroOptimizer,
+    get_static_torch_model,
+    post_process_colo_init_ctx,
+)
+from .low_level import LowLevelZeroOptimizer
+from .wrapper import zero_model_wrapper, zero_optim_wrapper
+__all__ = [
+    'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper',
+    'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model'
+]

View File

@@ -0,0 +1,11 @@
from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
from .gemini_ddp import GeminiDDP, ZeroDDP
from .gemini_mgr import GeminiManager
from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer
from .utils import get_static_torch_model
__all__ = [
'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP',
'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx'
]
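
The new `colossalai/zero/__init__.py` above re-exports everything from this `gemini` subpackage, so both import paths resolve to the same objects. A quick sanity check one could run, assuming a Colossal-AI build that contains this commit:

import colossalai.zero as zero
import colossalai.zero.gemini as gemini

# The top-level zero package is a thin re-export of the gemini subpackage.
assert zero.GeminiDDP is gemini.GeminiDDP
assert zero.GeminiAdamOptimizer is gemini.GeminiAdamOptimizer
assert zero.ColoInitContext is gemini.ColoInitContext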

View File

@@ -3,10 +3,11 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple
 import torch
-from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState
 from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device
+from .chunk import Chunk, ChunkFullError, TensorState
 class ChunkManager:
     """

View File

@@ -5,9 +5,9 @@ import numpy as np
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator
 from colossalai.tensor import ColoParameter
 from colossalai.utils import is_ddp_ignored
+from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:

View File

@@ -5,10 +5,11 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
-from colossalai.gemini.chunk import ChunkManager
-from colossalai.gemini.chunk.search_utils import search_chunk_configuration
 from colossalai.utils import is_ddp_ignored
+from .manager import ChunkManager
+from .search_utils import search_chunk_configuration
 def safe_div(a, b):
     if a == 0:

View File

@@ -3,10 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union
 import torch
 from torch import nn
-from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
 from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
-from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
+from .utils import InsertPostInitMethodToModuleSubClasses
 # find named_params includes replica
@@ -89,6 +87,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._default_dist_spec = default_dist_spec
     def _register_colo_modules(self):
+        from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module
         register_colo_module(torch.nn.Linear, ColoLinear())
         register_colo_module(torch.nn.Embedding, ColoEmbedding())

View File

@@ -0,0 +1,590 @@
import itertools
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Optional
import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ReplicaSpec
from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
from .gemini_hook import GeminiZeROHook
from .gemini_mgr import GeminiManager
from .memory_tracer import MemStats, OrderedParamGenerator
from .utils import get_temp_total_chunk_on_cuda
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
__all__ = [
'ZeroDDP',
'GeminiDDP',
]
class ZeroDDP(ColoDDP):
"""ZeRO DDP for ColoTensor.
Warning: Nested ZeroDDP is not supported now.
It is designed to be used with ChunkManager and GeminiManager.
For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``.
Args:
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
pin_memory (bool): Chunks on CPU Memory use pin-memory.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16.
Defaults to False.
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
"""
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False) -> None:
super().__init__(module, process_group=ColoProcessGroup())
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(gemini_manager)
self.fp32_params: List[ColoTensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self._cast_buffers()
self._logger = get_dist_logger()
if self.gemini_manager._premade_memstats_:
# build chunk in param runtime visited order.
param_order = self.gemini_manager.memstats()._param_runtime_order
else:
# build chunk in param initialized order.
# Note: in this way, it can not get filter unused params during runtime.
param_order = OrderedParamGenerator()
for p in module.parameters():
param_order.append(p)
self._init_chunks(param_order=param_order,
strict_ddp_mode=strict_ddp_mode,
cpu_offload=self.gemini_manager.policy_name != 'cuda',
pin_memory=pin_memory)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():
for p_name, p_var in m_var.named_parameters(recurse=False):
param_name = m_name + '.' + p_name if m_name else p_name
self.name2param[param_name] = p_var
def _post_forward(self):
"""This function is only triggered for inference.
"""
access_list = list(self.chunk_manager.accessed_chunks)
# we need to scatter all accessed chunks and move them to their original places
for chunk in access_list:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
else:
assert chunk.can_release
self.chunk_manager.release_chunk(chunk)
first_param = next(iter(chunk.tensors_info))
self.chunk_manager.move_chunk(chunk, self.grads_device[first_param])
assert self.chunk_manager.accessed_mem == 0
# reset all recorded attributes
self.gemini_manager.reset_attributes()
def forward(self, *args, **kwargs):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
# scatter chunks in the inference mode
if not grad_flag:
self._post_forward()
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
p.grad = None
def _pre_backward(self):
# set a visit label for all parameters
# the label is used to check whether the parameter is correctly reduced
for param in self.param2name:
if not is_ddp_ignored(param):
setattr(param, "_gemini_reduced", False)
def _post_backward(self):
if self.chunk_manager.accessed_mem != 0:
error_params = ["Reduction failed at followed parameters:"]
for param in self.param2name:
if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"):
error_params.append(self.param2name[param])
error_str = "\n\t".join(error_params)
raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.",
"The most possible reason is that the model is not compatible with ZeroDDP.\n",
f"{error_str}")
self._setup_grads_ptr()
self._logger.debug(
f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}'
)
self.gemini_manager.post_iter()
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
loss.backward()
self._post_backward()
def backward_by_grad(self, tensor, grad):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
torch.autograd.backward(tensor, grad)
self._post_backward()
def grad_handle(self, p, grad):
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter.")
self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE)
chunk.copy_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(chunk)
if reduced:
if chunk.is_gathered:
chunk.cuda_global_chunk.div_(chunk.pg_size)
else:
chunk.cuda_shard.div_(chunk.pg_size)
# check overflow elements
self.overflow_counter += chunk.has_inf_or_nan
# record l2 norm for gradient clipping
if chunk.l2_norm_flag:
chunk.set_l2_norm()
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
return empty_grad
def zero_grad(self, set_to_none: bool = False) -> None:
self.module.zero_grad(set_to_none=True)
def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
for tensor in chunk.get_tensors():
self.grads_device[tensor] = device
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
"""Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are included.
Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
Warning: The non strict state dict would ignore the parameters if the tensors of the parameters
are shared with other parameters which have been included in the dictionary.
When you need to load the state dict, you should set the argument `strict` to False.
Returns:
dict:
a dictionary containing a whole state of the module
"""
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0)
for hook in self._state_dict_hooks.values():
hook_result = hook(self, destination, prefix, local_metadata)
if hook_result is not None:
destination = hook_result
return destination
def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
"""
get param content from chunks.
Args:
param_list (_type_): a list of torch.nn.Parameters
only_rank_0 (_type_): _description_
Returns:
Dict: a dict whose key is param name and value is param with correct payload
"""
# save parameters
param_to_save_data = dict()
chunk_list = self.chunk_manager.get_chunks(param_list)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
record_tensor = torch.empty([0])
record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
if record_flag:
record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
assert tensor not in param_to_save_data
param_to_save_data[tensor] = record_tensor
del temp_chunk
return param_to_save_data
def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
# get copies of fp32 parameters in CPU
param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
record_parameter = param_to_save_data[fp32_p]
p_mapping[p] = record_parameter
for name, param in self.name2param.items():
if param is not None:
if is_ddp_ignored(param):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
# save all buffers
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
# save extra states
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~torch.nn.Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~torch.nn.Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
prefix = ''
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join(
'"{}"'.format(k) for k in unexpected_keys)))
if len(missing_keys) > 0:
error_msgs.insert(
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
self.__class__.__name__, "\n\t".join(error_msgs)))
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
def load(param_name, dest_tensor, copy_func):
state_key = prefix + param_name
if state_key in state_dict:
input_param = state_dict[state_key]
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if input_param.shape != dest_tensor.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(state_key, input_param.shape,
dest_tensor.shape))
return
try:
with torch.no_grad():
copy_func(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(state_key, dest_tensor.size(),
input_param.size(), ex.args))
elif strict:
missing_keys.append(state_key)
def load_fp32_parameter(chunk_slice, data):
chunk_slice.copy_(data.flatten())
for name, param in self.named_parameters():
if is_ddp_ignored(param):
# deal with ddp ignored parameters
load(name, param, param.copy_)
fp32_to_name = dict()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
if p is not None:
name = self.param2name[p]
fp32_to_name[fp32_p] = name
chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk)
for tensor, tensor_info in chunk.tensors_info.items():
parameter_name = fp32_to_name[tensor]
parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end]
load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice))
if chunk.is_gathered:
chunk.cuda_global_chunk.copy_(temp_chunk)
elif chunk.cuda_shard is not None:
chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
else:
chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end])
del temp_chunk
for chunk_32 in chunk_list:
chunk_16 = chunk_32.paired_chunk
assert chunk_16 is not None
chunk_16.optim_update()
for name, buf in persistent_buffers.items():
if buf is not None:
load(name, buf, buf.copy_)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state",
torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
if input_name not in local_state:
unexpected_keys.append(key)
def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool):
ddp_pg = ColoProcessGroup()
for p in param_order.generate():
assert isinstance(p, ColoParameter)
# gather sharded parameters in the strict ddp mode
if strict_ddp_mode:
if not p.is_replicate():
p.set_dist_spec(ReplicaSpec())
p.set_process_group(pg=ddp_pg)
# ignore the parameters with no gradient
if not p.requires_grad:
self.set_params_to_ignore([p])
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
self.chunk_manager.register_tensor(tensor=p,
group_type='fp16_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.chunk_manager.register_tensor(tensor=fp32_p,
group_type='fp32_param',
config_key=dp_world_size,
cpu_offload=cpu_offload,
pin_memory=pin_memory)
self.fp16_params.append(p)
self.fp32_params.append(fp32_p)
self.grads_device[p] = self.gemini_manager.default_device
self.chunk_manager.close_all_groups()
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
chunk_16 = self.chunk_manager.get_chunk(p)
chunk_32 = self.chunk_manager.get_chunk(fp32_p)
chunk_32.init_pair(chunk_16)
# keep gathered chunks in CUDA
if chunk_16.keep_gathered:
self.grads_device[p] = get_current_device()
def _cast_buffers(self):
for buffer in self.module.buffers():
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
class GeminiDDP(ZeroDDP):
def __init__(self,
module: torch.nn.Module,
device: torch.device,
placement_policy: str = "cpu",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
search_range_mb: int = 32,
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None) -> None:
"""
A torch.nn.Module wrapper using ZeRO-DP and Gemini.
ZeRO provides data parallelism; Gemini provides memory management.
WARNING: The class will modify the module in place!
Example:
The model should be initialized under the context of ColoInitContext.
>>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda")
>>> logits = model(x)
>>> loss = criterion(logits, labels)
>>> model.backward(loss)
Args:
module (torch.nn.Module): the model to be wrapped.
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force the model outputs to be fp32. Defaults to False.
search_range_mb (int, optional): chunk size search range in megabytes. Defaults to 32.
hidden_dim (int, optional): the hidden dimension of the DNN.
Users can provide this argument to speed up the search.
If users do not know this argument before training, that is fine; a default value of 1024 will be used.
min_chunk_size_mb (float, optional): the minimum chunk size in megabytes.
If the aggregate size of parameters is still smaller than the minimum chunk size,
all parameters will be compacted into one small chunk.
memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
"""
# an ugly hotfix for compatibility with Lightning
if search_range_mb is None:
search_range_mb = 32
chunk_manager = init_chunk_manager(model=module,
init_device=device,
hidden_dim=hidden_dim,
search_range_mb=search_range_mb,
min_chunk_size_mb=min_chunk_size_mb,
strict_ddp_flag=strict_ddp_mode)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode)
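For reviewers, a minimal end-to-end sketch of the relocated Gemini API is given below. It only illustrates the new `colossalai.zero` import paths introduced by this commit; the toy model, batch, and hyper-parameters are placeholder assumptions, and the training step follows the pattern in the class docstring above rather than a verified script.

import torch
import torch.nn as nn

import colossalai
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP

colossalai.launch_from_torch(config={})

# initialize the model under ColoInitContext so its parameters become ColoParameters
with ColoInitContext(device=get_current_device()):
    model = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))    # placeholder model

# wrap with Gemini + ZeRO-DP and the convenience Adam wrapper
model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**14)

criterion = nn.CrossEntropyLoss()
data = torch.randn(8, 32, device=get_current_device())
labels = torch.randint(0, 2, (8,), device=get_current_device())

logits = model(data)
loss = criterion(logits, labels)
model.backward(loss)    # ZeroDDP handles the backward pass itself, as in the docstring above
optimizer.step()
optimizer.zero_grad()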

View File

@ -5,10 +5,10 @@ from typing import List
import torch import torch
from colossalai.gemini import TensorState
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager
class TrainingPhase(Enum): class TrainingPhase(Enum):

View File

@ -4,10 +4,8 @@ from typing import List, Optional, Tuple
import torch import torch
from colossalai.gemini.chunk import Chunk, ChunkManager from .chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import MemStats from .memory_tracer import ChunkMemStatsCollector, MemStats
from .memory_tracer import ChunkMemStatsCollector
from .placement_policy import PlacementPolicyFactory from .placement_policy import PlacementPolicyFactory

View File

@ -10,12 +10,15 @@ from torch.nn import Parameter
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.nn.parallel.data_parallel import ZeroDDP
from colossalai.utils import disposable, get_current_device, is_ddp_ignored from colossalai.utils import disposable, get_current_device, is_ddp_ignored
from .chunk import Chunk, ChunkManager
from .gemini_ddp import ZeroDDP
__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
@ -316,3 +319,10 @@ class ZeroOptimizer(ColossalaiOptimizer):
fake_params_list.append(fake_param) fake_params_list.append(fake_param)
group['params'] = fake_params_list group['params'] = fake_params_list
class GeminiAdamOptimizer(ZeroOptimizer):
def __init__(self, model: torch.nn.Module, **defaults: Any) -> None:
optimizer = HybridAdam(model.parameters(), **defaults)
super().__init__(optimizer, model, **defaults)
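As the hunk above shows, GeminiAdamOptimizer is only a thin convenience wrapper around HybridAdam and ZeroOptimizer. The sketch below spells out that equivalence; it assumes `model` is already wrapped in ZeroDDP/GeminiDDP and that ZeroOptimizer is re-exported from colossalai.zero as elsewhere in this diff.

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.zero.gemini.gemini_optimizer import GeminiAdamOptimizer

# explicit form: build HybridAdam yourself, then hand it to ZeroOptimizer
optimizer = ZeroOptimizer(HybridAdam(model.parameters(), lr=1e-3), model, initial_scale=2**14)

# shorthand: GeminiAdamOptimizer performs the same two steps internally (see the hunk above)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**14)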

View File

@ -1,10 +1,10 @@
from typing import Optional from typing import Optional
from colossalai.gemini.chunk import ChunkManager
from colossalai.gemini.memory_tracer import MemStats
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.chunk import ChunkManager
from .memory_stats import MemStats
from .memstats_collector import MemStatsCollector from .memstats_collector import MemStatsCollector

View File

@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional
import torch import torch
from colossalai.gemini.memory_tracer import OrderedParamGenerator from .param_runtime_order import OrderedParamGenerator
class MemStats(object): class MemStats(object):

View File

@ -1,12 +1,7 @@
import time import time
from typing import List, Optional from typing import Optional
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used
from .memory_monitor import SyncCudaMemoryMonitor
from .memory_stats import MemStats from .memory_stats import MemStats
@ -49,7 +44,7 @@ class MemStatsCollector:
assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \
f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\
f"step total {self._step_total}" f"step total {self._step_total}"
next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx]
self._step_idx = (self._step_idx + 1) % self._step_total self._step_idx = (self._step_idx + 1) % self._step_total
return next_non_model_data return next_non_model_data
@ -75,6 +70,8 @@ class MemStatsCollector:
Sampling model data statistics. Sampling model data statistics.
""" """
if self._start_flag and not self.use_outside_memstats: if self._start_flag and not self.use_outside_memstats:
from colossalai.zero.legacy.gemini import StatefulTensor
# The following code work for ZeroInitContext, which is deprecated in v0.1.12 # The following code work for ZeroInitContext, which is deprecated in v0.1.12
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
self._memstats.record_max_cuda_model_data(cuda_mem) self._memstats.record_max_cuda_model_data(cuda_mem)

View File

@ -1,9 +1,14 @@
import torch.nn import torch.nn
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import (
GradMemStats,
GradMemTracerHook,
ParamMemTracerHook,
)
from .memory_stats import MemStats
__all__ = ['RuntimeMemTracer'] __all__ = ['RuntimeMemTracer']

View File

@ -6,7 +6,7 @@ from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta
from colossalai.gemini.chunk import ChunkManager from colossalai.zero.gemini.chunk import ChunkManager
if is_compatible_with_meta(): if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor from colossalai.fx.profiler import MetaTensor

View File

@ -5,11 +5,12 @@ from typing import Dict, List, Optional, Tuple, Type
import torch import torch
from colossalai.gemini.chunk import Chunk, ChunkManager
from colossalai.gemini.memory_tracer import ChunkMemStatsCollector
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector
class PlacementPolicy(ABC): class PlacementPolicy(ABC):
need_mem_stats: bool = False need_mem_stats: bool = False

View File

@ -6,9 +6,10 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from colossalai.gemini.chunk import Chunk
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .chunk import Chunk
def get_temp_total_chunk_on_cuda(chunk: Chunk): def get_temp_total_chunk_on_cuda(chunk: Chunk):
if chunk.is_gathered: if chunk.is_gathered:
@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model,
Returns: Returns:
torch.nn.Module: a static torch model used for saving checkpoints or numeric checks torch.nn.Module: a static torch model used for saving checkpoints or numeric checks
""" """
from colossalai.nn.parallel import ZeroDDP from colossalai.zero.gemini.gemini_ddp import ZeroDDP
assert isinstance(zero_ddp_model, ZeroDDP) assert isinstance(zero_ddp_model, ZeroDDP)
state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0)

View File

@ -0,0 +1,44 @@
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.logging import get_dist_logger
from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator
from .sharded_model import ShardedModelV2
from .sharded_optim import ShardedOptimizerV2
def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]:
"""
A helper function to integrate the model and optimizer with the ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer: Your optimizer object
:type optimizer: :class:`torch.optim.Optimizer`
:param model_config: The configuration dict passed to :class:`ShardedModelV2`
:type model_config: :class:`dict`
:param optimizer_config: The configuration dict passed to :class:`ShardedOptimizerV2`
:type optimizer_config: :class:`dict`
:return: (model, optimizer)
:rtype: Tuple
"""
logger = get_dist_logger('convert_to_zero_v2')
logger.info(f'optimizer_config is {optimizer_config}', ranks=[0])
if optimizer_config is None:
optimizer_config = dict()
logger.info(f'model_config is {model_config}', ranks=[0])
if model_config is None:
model_config = dict()
zero_model = ShardedModelV2(model, **model_config)
zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config)
return zero_model, zero_optimizer
__all__ = [
'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context',
'no_shard_zero_decrator'
]
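A hedged sketch of driving this legacy entry point directly follows. It assumes this new file is the `colossalai.zero.legacy` package __init__, borrows the ZeroInitContext/TensorShardStrategy pattern from the BERT pretraining example later in this diff, and uses a placeholder model; the config keys are assumptions based on the ShardedModelV2/ShardedOptimizerV2 signatures.

import torch
import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.legacy import ZeroInitContext, convert_to_zero_v2
from colossalai.zero.legacy.shard_utils import TensorShardStrategy

# build the model with its parameters already sharded, as in the BERT example below
shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(32, 32)    # placeholder model

optimizer = HybridAdam(model.parameters(), lr=1e-3)
# wraps the pair into ShardedModelV2 / ShardedOptimizerV2 as defined above
zero_model, zero_optimizer = convert_to_zero_v2(
    model, optimizer,
    model_config=dict(shard_strategy=shard_strategy, tensor_placement_policy='cpu'),
    optimizer_config=dict(initial_scale=2**14))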

View File

@ -0,0 +1,9 @@
from .ophooks import BaseOpHook, register_ophooks_recursively
from .stateful_tensor import StatefulTensor
from .stateful_tensor_mgr import StatefulTensorMgr
from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy
__all__ = [
'StatefulTensorMgr', 'StatefulTensor', 'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy',
'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook'
]

View File

@ -1,4 +1,5 @@
import torch import torch
from colossalai.registry import OPHOOKS from colossalai.registry import OPHOOKS
from . import BaseOpHook from . import BaseOpHook

View File

@ -5,9 +5,9 @@ from typing import List
import torch import torch
from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.gemini.tensor_utils import alloc_storage, free_storage
from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor
from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage
class TrainingPhase(Enum): class TrainingPhase(Enum):

View File

@ -1,9 +1,9 @@
from enum import Enum from enum import Enum
from typing import Optional from typing import Optional, Union
import torch
from typing import Union
from colossalai.gemini.gemini_context import GeminiMemoryManager import torch
from .gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor): def sizeof_tensor(tensor: torch.Tensor):
@ -19,7 +19,7 @@ class TensorState(Enum):
class StatefulTensor(object): class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states. """A Structure stores a Torch Tensor and labeled states.
Inspired from the paper: Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

View File

@ -1,13 +1,16 @@
import functools import functools
import torch
import types import types
from colossalai.utils.cuda import get_current_device
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
from time import time from time import time
from typing import List
import torch
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device
from .stateful_tensor import StatefulTensor, TensorState
from .tensor_placement_policy import TensorPlacementPolicy
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class StatefulTensorMgr(object): class StatefulTensorMgr(object):

View File

@ -5,11 +5,12 @@ from typing import List, Optional, Type
import torch import torch
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from .stateful_tensor import StatefulTensor
from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
class TensorPlacementPolicy(ABC): class TensorPlacementPolicy(ABC):

View File

@ -1,6 +1,8 @@
from typing import Tuple, Union
import torch import torch
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple from .stateful_tensor import StatefulTensor
def is_storage_empty(tensor: torch.Tensor) -> bool: def is_storage_empty(tensor: torch.Tensor) -> bool:

View File

@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.sharded_param import ShardedParamV2 from colossalai.zero.legacy.sharded_param import ShardedParamV2
@dataclass @dataclass

View File

@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from typing import List, Optional from typing import List, Optional
import torch.distributed as dist import torch.distributed as dist
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class BaseShardStrategy(ABC): class BaseShardStrategy(ABC):

View File

@ -2,17 +2,18 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from torch._utils import _flatten_dense_tensors as flatten from torch._utils import _flatten_dense_tensors as flatten
from colossalai.utils import get_current_device
from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
from .tensor_shard_strategy import TensorShardStrategy from .tensor_shard_strategy import TensorShardStrategy
class BucketTensorShardStrategy(TensorShardStrategy): class BucketTensorShardStrategy(TensorShardStrategy):
"""Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together,
which will fully utilize network bandwidth. which will fully utilize network bandwidth.
It is especially useful when sub-module contains bias, It is especially useful when sub-module contains bias,
since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small). since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usually small).
""" """

View File

@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
from typing import Tuple from typing import Tuple
import torch
def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor.""" """Return the local shard of a full tensor."""

View File

@ -2,11 +2,12 @@ from typing import List, Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils.commons import get_shard from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.legacy.shard_utils.commons import get_shard
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy): class TensorShardStrategy(BaseShardStrategy):
@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy):
Args: Args:
t (ShardedTensor): a tensor to be sharded. t (ShardedTensor): a tensor to be sharded.
process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards.
Defaults to None. Defaults to None.
""" """
if t.is_sharded: if t.is_sharded:

View File

@ -1,3 +1,3 @@
from .sharded_model_v2 import ShardedModelV2 from .sharded_model_v2 import ShardedModelV2
__all__ = ['ShardedModelV2'] __all__ = ['ShardedModelV2']

View File

@ -1,9 +1,9 @@
from typing import Any, Callable, List, Tuple from typing import Any, Callable, List, Tuple, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Union
from colossalai.gemini.stateful_tensor import StatefulTensor from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float: def get_gradient_predivide_factor(world_size: int) -> float:

View File

@ -13,19 +13,18 @@ from torch.nn.parameter import Parameter
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.gemini.ophooks import register_ophooks_recursively
from colossalai.gemini.paramhooks import BaseParamHookMgr
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import disposable, get_current_device from colossalai.utils import disposable, get_current_device
from colossalai.utils.memory import colo_device_memory_capacity from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively
from colossalai.zero.utils import ZeroHook from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer
from ._utils import ( from ._utils import (
cast_float_arguments, cast_float_arguments,
@ -35,6 +34,7 @@ from ._utils import (
free_storage, free_storage,
get_gradient_predivide_factor, get_gradient_predivide_factor,
) )
from .zero_hook import ZeroHook
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX

View File

@ -1,8 +1,9 @@
import torch
from colossalai.zero.sharded_model import ShardedModelV2
import copy import copy
import torch
from colossalai.zero.legacy.sharded_model import ShardedModelV2
def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module):
""" """

View File

@ -3,14 +3,14 @@ from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.registry import OPHOOKS from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.gemini.memory_tracer import MemStatsCollector
from colossalai.zero.legacy.gemini.ophooks import BaseOpHook
from colossalai.zero.legacy.gemini.stateful_tensor import TensorState
from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
@OPHOOKS.register_module @OPHOOKS.register_module

View File

@ -0,0 +1,3 @@
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizerV2']

View File

@ -14,13 +14,13 @@ from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.zero.sharded_model import ShardedModelV2 from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32
class OptimState(Enum): class OptimState(Enum):

View File

@ -0,0 +1,4 @@
from .sharded_param import ShardedParamV2
from .sharded_tensor import ShardedTensor
__all__ = ['ShardedTensor', 'ShardedParamV2']

View File

@ -1,9 +1,11 @@
from typing import List, Optional, Tuple
import torch import torch
from typing import Optional, Tuple
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from typing import List from .sharded_tensor import ShardedTensor
EMPTY_TENSOR_DICT = {} EMPTY_TENSOR_DICT = {}

View File

@ -1,5 +1,6 @@
import torch import torch
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor): class ShardedTensor(StatefulTensor):

View File

@ -0,0 +1,3 @@
from .low_level_optim import LowLevelZeroOptimizer
__all__ = ['LowLevelZeroOptimizer']

View File

@ -1,4 +0,0 @@
from .low_level_optim import LowLevelZeroOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2
__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer']

View File

@ -1,4 +0,0 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
__all__ = ['ShardedTensor', 'ShardedParamV2']

View File

@ -1,3 +0,0 @@
from .zero_hook import ZeroHook
__all__ = ['ZeroHook']

View File

@ -4,7 +4,7 @@ from typing import Dict, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from .gemini_parallel import GeminiDDP from .gemini import GeminiDDP
def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None): def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None):
@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module,
config_dict['max_scale'] = max_scale config_dict['max_scale'] = max_scale
if zero_stage in [1, 2]: if zero_stage in [1, 2]:
from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer from colossalai.zero.low_level import LowLevelZeroOptimizer
config_dict['partition_grad'] = zero_stage == 2 config_dict['partition_grad'] = zero_stage == 2
config_dict['clip_grad_norm'] = max_norm config_dict['clip_grad_norm'] = max_norm
return LowLevelZeroOptimizer(optimizer, **config_dict) return LowLevelZeroOptimizer(optimizer, **config_dict)
else: else:
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer
config_dict['clipping_norm'] = max_norm config_dict['clipping_norm'] = max_norm
return ZeroOptimizer(optimizer, model, **config_dict) return ZeroOptimizer(optimizer, model, **config_dict)
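Below is a minimal sketch of the two wrapper entry points under their new colossalai.zero home, mirroring the example and test updates further down. The gemini_config keys are taken from the GeminiDDP signature earlier in this diff; the toy model and the initial_scale/max_norm values are assumptions.

import torch.nn as nn

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

with ColoInitContext(device=get_current_device()):
    model = nn.Linear(64, 64)    # placeholder model

# zero_stage 1/2 -> LowLevelZeroOptimizer, zero_stage 3 -> Gemini ZeroOptimizer (see the hunk above)
zero_stage = 3
gemini_config = dict(device=get_current_device(), placement_policy='cpu', pin_memory=True)
model = zero_model_wrapper(model, zero_stage=zero_stage, gemini_config=gemini_config)
optimizer = zero_optim_wrapper(model, HybridAdam(model.parameters(), lr=1e-3),
                               initial_scale=2**14, max_norm=1.0)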

View File

@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
``` ```

View File

@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
``` ```

View File

@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py
from diffusers import AutoencoderKL from diffusers import AutoencoderKL
import colossalai import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx from colossalai.zero import ColoInitContext, post_process_colo_init_ctx
path = "/data/scratch/diffuser/stable-diffusion-v1-4" path = "/data/scratch/diffuser/stable-diffusion-v1-4"

View File

@ -21,10 +21,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()

View File

@ -23,10 +23,9 @@ import colossalai
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()

View File

@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext
def set_seed(seed): def set_seed(seed):

View File

@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext
def init_1d_row_for_linear_weight_spec(model, world_size: int): def init_1d_row_for_linear_weight_spec(model, world_size: int):

View File

@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__ CAI_VERSION = colossalai.__version__

View File

@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
CAI_VERSION = colossalai.__version__ CAI_VERSION = colossalai.__version__

View File

@ -34,12 +34,9 @@ from transformers.utils.versions import require_version
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP
def get_data(batch_size, seq_len, vocab_size): def get_data(batch_size, seq_len, vocab_size):
@ -179,13 +176,15 @@ def main():
# build model # build model
if args.model_name_or_path is None: if args.model_name_or_path is None:
logger.info("Train a new model from scratch", ranks=[0]) logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half, with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec, default_dist_spec=default_dist_spec,
default_pg=shard_pg): default_pg=shard_pg):
model = OPTForCausalLM(config) model = OPTForCausalLM(config)
else: else:
logger.info("Finetune a pre-trained model", ranks=[0]) logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev, dtype=torch.half, with ColoInitContext(device=init_dev,
dtype=torch.half,
default_dist_spec=default_dist_spec, default_dist_spec=default_dist_spec,
default_pg=shard_pg): default_pg=shard_pg):
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
@ -198,8 +197,11 @@ def main():
numel = sum([p.numel() for p in model.parameters()]) numel = sum([p.numel() for p in model.parameters()])
PLACEMENT_POLICY = 'cpu' PLACEMENT_POLICY = 'cpu'
model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, model = GeminiDDP(model,
pin_memory=True, strict_ddp_mode=args.shardinit) device=get_current_device(),
placement_policy=PLACEMENT_POLICY,
pin_memory=True,
strict_ddp_mode=args.shardinit)
optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0)
SEQ_LEN = 1024 SEQ_LEN = 1024

View File

@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device from colossalai.utils import MultiTimer, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
# constants # constants
@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
return model return model
## Parameter Sharding Strategies for Tensor Parallelism # Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec) param.set_tensor_spec(*spec)
@ -232,7 +230,7 @@ if args.distplan == "colossalai":
tensor_parallelize(model, pg) tensor_parallelize(model, pg)
model = gemini_zero_dpp(model, pg, args.placement) model = gemini_zero_dpp(model, pg, args.placement)
#optimizer # optimizer
#optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5)
optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) optimizer = GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5)

View File

@ -1,69 +1,67 @@
import colossalai
import math import math
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import colossalai.nn as col_nn
from arguments import parse_args
from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt
from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args
from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer
from utils.logger import Logger
from evaluation import evaluate
from loss import LossForPretraining
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from tqdm import tqdm
import os import os
import time import time
from functools import partial from functools import partial
import torch
from arguments import parse_args
from evaluation import evaluate
from loss import LossForPretraining
from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider
from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt
from tqdm import tqdm
from transformers import AutoTokenizer from transformers import AutoTokenizer
from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator
from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables
from utils.logger import Logger
from colossalai.gemini import ChunkManager, GeminiManager import colossalai
from colossalai.utils.model.colo_init_context import ColoInitContext import colossalai.nn as col_nn
from colossalai.utils import get_current_device from colossalai.context import ParallelMode
from colossalai.nn.parallel import ZeroDDP from colossalai.core import global_context as gpc
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device
from colossalai.zero import ZeroOptimizer
from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager
from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
def main(): def main():
args = parse_args() args = parse_args()
launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug)
if args.vscode_debug: if args.vscode_debug:
colossalai.launch(config={}, colossalai.launch(config={},
rank=args.rank, rank=args.rank,
world_size=args.world_size, world_size=args.world_size,
host=args.host, host=args.host,
port=args.port, port=args.port,
backend=args.backend) backend=args.backend)
args.local_rank = -1 args.local_rank = -1
args.log_interval = 1 args.log_interval = 1
else: else:
colossalai.launch_from_torch(args.colossal_config) #args.colossal_config colossalai.launch_from_torch(args.colossal_config) # args.colossal_config
args.local_rank = int(os.environ["LOCAL_RANK"]) args.local_rank = int(os.environ["LOCAL_RANK"])
logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + logger.info(
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' +
f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}'
)
log_args(logger, args) log_args(logger, args)
args.tokenizer = tokenizer args.tokenizer = tokenizer
args.logger = logger args.logger = logger
set_global_variables(launch_time, args.tensorboard_path) set_global_variables(launch_time, args.tensorboard_path)
use_zero = hasattr(gpc.config, 'zero') use_zero = hasattr(gpc.config, 'zero')
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
@ -71,8 +69,8 @@ def main():
if use_zero: if use_zero:
shard_strategy = TensorShardStrategy() shard_strategy = TensorShardStrategy()
with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy,
shard_param=True): shard_param=True):
config, model, numel = get_model(args, logger) config, model, numel = get_model(args, logger)
# model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True)
else: else:
@ -82,9 +80,10 @@ def main():
os.mkdir(os.path.join(args.ckpt_path, launch_time)) os.mkdir(os.path.join(args.ckpt_path, launch_time))
logger.info(f'Model numel: {numel}') logger.info(f'Model numel: {numel}')
get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) # len(dataloader)
steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size
total_steps = steps_per_epoch * args.epoch total_steps = steps_per_epoch * args.epoch
# build optimizer and lr_scheduler # build optimizer and lr_scheduler
@ -98,18 +97,23 @@ def main():
o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1
optimizer = get_optimizer(model, lr=args.lr) optimizer = get_optimizer(model, lr=args.lr)
optimizer.load_state_dict(o_l_state_dict['optimizer']) optimizer.load_state_dict(o_l_state_dict['optimizer'])
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] # o_l_state_dict['lr_scheduler']['last_epoch']
lr_scheduler = get_lr_scheduler(optimizer,
total_steps=total_steps,
last_epoch=o_l_state_dict['lr_scheduler']['last_epoch'])
for state in optimizer.state.values(): for state in optimizer.state.values():
for k, v in state.items(): for k, v in state.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}")
# if you want delete the above three code, have to move the model to gpu, because in optimizer.step() # if you want delete the above three code, have to move the model to gpu, because in optimizer.step()
lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler'])
start_epoch = o_l_state_dict['epoch'] start_epoch = o_l_state_dict['epoch']
start_shard = o_l_state_dict['shard'] + 1 start_shard = o_l_state_dict['shard'] + 1
# global_step = o_l_state_dict['global_step'] + 1 # global_step = o_l_state_dict['global_step'] + 1
logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') logger.info(
f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}'
)
else: else:
optimizer = get_optimizer(model, lr=args.lr) optimizer = get_optimizer(model, lr=args.lr)
lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1)
@ -124,12 +128,11 @@ def main():
# initialize with colossalai # initialize with colossalai
engine, _, _, lr_scheduelr = colossalai.initialize(model=model, engine, _, _, lr_scheduelr = colossalai.initialize(model=model,
optimizer=optimizer, optimizer=optimizer,
criterion=criterion, criterion=criterion,
lr_scheduler=lr_scheduler) lr_scheduler=lr_scheduler)
logger.info(get_mem_info(prefix='After init model, ')) logger.info(get_mem_info(prefix='After init model, '))
best_loss = None best_loss = None
eval_loss = 0 eval_loss = 0
@ -146,13 +149,16 @@ def main():
dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard)
# pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:
iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) iterator_data = tqdm(enumerate(dataset_iterator),
total=(total_length // args.train_micro_batch_size_per_gpu // world_size),
colour='cyan',
smoothing=1)
else: else:
iterator_data = enumerate(dataset_iterator) iterator_data = enumerate(dataset_iterator)
engine.train() engine.train()
for step, batch_data in iterator_data: for step, batch_data in iterator_data:
# batch_data = pretrain_dataset_provider.get_batch(batch_index) # batch_data = pretrain_dataset_provider.get_batch(batch_index)
input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}")
@ -162,7 +168,7 @@ def main():
# nsp_label = batch_data[5].cuda() # nsp_label = batch_data[5].cuda()
output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
loss = engine.criterion(output.logits, mlm_label) loss = engine.criterion(output.logits, mlm_label)
pretrain_dataset_provider.prefetch_batch() pretrain_dataset_provider.prefetch_batch()
@ -172,14 +178,15 @@ def main():
engine.step() engine.step()
lr_scheduelr.step() lr_scheduelr.step()
engine.zero_grad() engine.zero_grad()
global_step += 1 global_step += 1
if global_step % args.log_interval == 0 and global_step != 0 \ if global_step % args.log_interval == 0 and global_step != 0 \
and torch.distributed.get_rank() == 0: and torch.distributed.get_rank() == 0:
elapsed_time = timers('interval_time').elapsed(reset=False) elapsed_time = timers('interval_time').elapsed(reset=False)
elapsed_time_per_iteration = elapsed_time / global_step elapsed_time_per_iteration = elapsed_time / global_step
samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(
numel, args, config, elapsed_time, global_step, world_size)
cur_loss = train_loss / args.log_interval cur_loss = train_loss / args.log_interval
current_lr = lr_scheduelr.get_last_lr()[0] current_lr = lr_scheduelr.get_last_lr()[0]
@ -189,12 +196,13 @@ def main():
if args.wandb: if args.wandb:
tensorboard_log = get_tensorboard_writer() tensorboard_log = get_tensorboard_writer()
tensorboard_log.log_train({ tensorboard_log.log_train(
'lr': current_lr, {
'loss': cur_loss, 'lr': current_lr,
'ppl': math.exp(cur_loss), 'loss': cur_loss,
'mins_batch': elapsed_time_per_iteration 'ppl': math.exp(cur_loss),
}, global_step) 'mins_batch': elapsed_time_per_iteration
}, global_step)
train_loss = 0 train_loss = 0
@ -202,12 +210,14 @@ def main():
logger.info('*' * 100) logger.info('*' * 100)
eval_loss += evaluate(engine, args, logger, global_step) eval_loss += evaluate(engine, args, logger, global_step)
save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) save_ckpt(engine.model, optimizer, lr_scheduelr,
os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch,
shard, global_step)
eval_loss /= len(os.listdir(args.data_path_prefix)) eval_loss /= len(os.listdir(args.data_path_prefix))
logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ logger.info(
f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins'
+ f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}')
logger.info('-' * 100) logger.info('-' * 100)
if args.wandb and torch.distributed.get_rank() == 0: if args.wandb and torch.distributed.get_rank() == 0:
tensorboard_log = get_tensorboard_writer() tensorboard_log = get_tensorboard_writer()

View File

@@ -30,24 +30,13 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
+import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-import colossalai
-import transformers
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -61,6 +50,15 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())

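Note: the hunk above is typical of the whole commit. ZeRO- and Gemini-related entry points that used to be scattered across colossalai.nn.parallel, colossalai.nn.optimizer.zero_optimizer and colossalai.utils.model.colo_init_context are now imported from the single colossalai.zero package. A minimal before/after sketch, using only the names visible in this diff (no other API details are implied):

# old import paths removed by this commit
# from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
# from colossalai.nn.parallel import ZeroDDP
# from colossalai.utils.model.colo_init_context import ColoInitContext

# new consolidated import path
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer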
View File

@@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.testing import parameterize
 from colossalai.utils import free_port, get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
 from tests.test_auto_parallel.test_offload.model_utils import *
 from tests.test_tensor.common_utils import set_seed

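Note: in the hunk above, zero_model_wrapper and zero_optim_wrapper move from colossalai.nn.parallel to colossalai.zero. A hypothetical usage sketch of the relocated helpers follows; it assumes a distributed environment has already been initialized via colossalai.launch, and the keyword arguments (zero_stage, lr) are illustrative guesses rather than signatures confirmed by this diff:

import torch

from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper

# build the model under ColoInitContext so its parameters are created as ColoTensors
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024))

# wrap model and optimizer with the relocated ZeRO helpers
model = zero_model_wrapper(model, zero_stage=3)
optimizer = zero_optim_wrapper(model, HybridAdam(model.parameters(), lr=1e-3))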
View File

@@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
 from colossalai.tensor.process_group import ProcessGroup
 from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port, get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx
+from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper
 
 class MLP(torch.nn.Module):

View File

@@ -10,14 +10,14 @@ import torch.distributed as dist
 import torch.multiprocessing as mp
 import colossalai
-from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration
-from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.nn.parallel import ColoDDP, ZeroDDP
+from colossalai.nn.parallel import ColoDDP
 from colossalai.tensor import ProcessGroup
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
 
 def set_seed(seed):

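Note: in the hunk above, ZeroDDP becomes importable from colossalai.zero while the Gemini internals (ChunkManager, GeminiManager, search_chunk_configuration) move under colossalai.zero.gemini. A hypothetical sketch of how these pieces are typically wired together; the argument names (search_range_mb, search_interval_byte, the placement-policy string, pin_memory) are assumptions, not taken from this diff, and a prior colossalai.launch call is assumed:

import torch

from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager

with ColoInitContext(device=torch.device('cuda')):
    model = torch.nn.Linear(2048, 2048)

# search a chunk configuration for the model, then let GeminiManager drive placement
config_dict, *_ = search_chunk_configuration(model, search_range_mb=32, search_interval_byte=1024)
chunk_manager = ChunkManager(config_dict)
gemini_manager = GeminiManager('cuda', chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)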
View File

@@ -1,18 +1,19 @@
 import copy
+from collections import OrderedDict
+from functools import partial
 import pytest
-import colossalai
 import torch
 import torch.multiprocessing as mp
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils import free_port
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from functools import partial
-from tests.components_to_test.registry import non_distributed_component_funcs
+import colossalai
 from colossalai.nn.parallel import ColoDDP
-from collections import OrderedDict
-from colossalai.tensor import ProcessGroup, ColoParameter
+from colossalai.tensor import ColoParameter, ProcessGroup
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):

View File

@@ -1,73 +1,73 @@
 import pytest
 import torch
-from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
+from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState
 
 @pytest.mark.dist
 def test_gemini_manager():
     # reset the manager, in case that there exists memory information left
     manager = StatefulTensor.GST_MGR
     manager.reset()
 
     # occupation 8
     st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
     # occupation 60
     st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
 
     # occupation 28
     t1 = torch.empty(7, device='cuda')
     # occupation 12
     t2 = torch.empty(3, device='cpu')
     st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
     st4 = StatefulTensor(None, TensorState.FREE)
 
     assert manager.total_number == 4
     assert manager.total_mem['cpu'] == 60
     assert manager.total_mem['cuda'] == 36
     assert manager.state_mem['cpu'][TensorState.HOLD] == 60
     assert manager.state_mem['cuda'][TensorState.HOLD] == 8
     assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
 
     st4.payload_reset(t2)
     st3.payload_reset(t2)
 
     assert manager.total_number == 4
     assert manager.total_mem['cpu'] == 84
     assert manager.total_mem['cuda'] == 8
     assert manager.state_mem['cpu'][TensorState.HOLD] == 72
     assert manager.state_mem['cuda'][TensorState.HOLD] == 8
     assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
     assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
 
     st1.move_to(torch.device('cpu'))
     st2.move_to(torch.device('cpu'))
     st3.move_to(torch.device('cuda', 0))
 
     assert manager.total_number == 4
     assert manager.total_mem['cpu'] == 80
     assert manager.total_mem['cuda'] == 12
     assert manager.state_mem['cpu'][TensorState.HOLD] == 80
     assert manager.state_mem['cuda'][TensorState.HOLD] == 0
     assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
     assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
 
     st1.trans_state(TensorState.COMPUTE)
     st2.trans_state(TensorState.COMPUTE)
     st2.trans_state(TensorState.HOLD_AFTER_BWD)
 
     assert manager.total_number == 4
     assert manager.total_mem['cpu'] == 80
     assert manager.total_mem['cuda'] == 12
     assert manager.state_mem['cpu'][TensorState.HOLD] == 12
     assert manager.state_mem['cuda'][TensorState.HOLD] == 0
     assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
     assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
     assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
     assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
     assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
     assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
 
 if __name__ == '__main__':
     test_gemini_manager()

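Note: the only functional change in the test above is the import path (colossalai.gemini becomes colossalai.zero.legacy.gemini). The "occupation" comments in its body are plain numel-times-element-size arithmetic, reproduced here for clarity (float16 is 2 bytes per element; torch.empty's default float32 is 4 bytes):

assert 2 * 2 * 2 == 8    # st1: 2x2 tensor, float16 -> 2 bytes/element
assert 3 * 5 * 4 == 60   # st2: 3x5 tensor, float32 -> 4 bytes/element
assert 7 * 4 == 28       # t1: 7 elements, default float32
assert 3 * 4 == 12       # t2: 3 elements, default float32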
Some files were not shown because too many files have changed in this diff.