diff --git a/applications/Chat/coati/trainer/strategies/colossalai.py b/applications/Chat/coati/trainer/strategies/colossalai.py index 521c53640..ba85ba76d 100644 --- a/applications/Chat/coati/trainer/strategies/colossalai.py +++ b/applications/Chat/coati/trainer/strategies/colossalai.py @@ -14,17 +14,16 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase import colossalai from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import CPUAdam, HybridAdam -from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.tensor import ProcessGroup, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext - -logger = get_dist_logger(__name__) +from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.utils import get_static_torch_model from .base import Strategy from .ddp import DDPStrategy +logger = get_dist_logger(__name__) + class ColossalAIStrategy(DDPStrategy): """ diff --git a/colossalai/auto_parallel/offload/base_offload_module.py b/colossalai/auto_parallel/offload/base_offload_module.py index 3a32f0722..d0c328e13 100644 --- a/colossalai/auto_parallel/offload/base_offload_module.py +++ b/colossalai/auto_parallel/offload/base_offload_module.py @@ -4,8 +4,8 @@ from typing import Optional, Set import torch import torch.nn as nn -from colossalai.gemini.tensor_utils import free_storage from colossalai.nn.parallel.data_parallel import _cast_float +from colossalai.zero.legacy.gemini.tensor_utils import free_storage from .region_manager import RegionManager from .util import GlobalRuntimeInfo diff --git a/colossalai/auto_parallel/offload/region.py b/colossalai/auto_parallel/offload/region.py index e6907cc4b..9a2f558c3 100644 --- a/colossalai/auto_parallel/offload/region.py +++ b/colossalai/auto_parallel/offload/region.py @@ -1,7 +1,10 @@ -from typing import List, Dict, Tuple +from typing import Dict, List, Tuple + import torch from torch.fx import Node -from colossalai.gemini.tensor_utils import alloc_storage, free_storage + +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage + class Region: """ @@ -52,15 +55,13 @@ class Region: Map the parameters in the region to a contiguous memory space. 
""" - self.fp16_data = torch.zeros( - self.param_num, dtype=torch.half, device='cuda') + self.fp16_data = torch.zeros(self.param_num, dtype=torch.half, device='cuda') offset = 0 for param in self.fp16_params: param.data = param.data.cuda() p_num = param.data.numel() self.fp16_data[offset:offset + p_num].copy_(param.data.flatten()) - param.data = self.fp16_data[offset:offset + - p_num].view(param.data.shape) + param.data = self.fp16_data[offset:offset + p_num].view(param.data.shape) self.param_to_range[param] = (offset, offset + p_num) offset += p_num @@ -141,4 +142,4 @@ class Region: def __update_params_ptr(self) -> None: for param in self.fp16_params: begin, end = self.param_to_range[param] - param.data = self.fp16_data[begin:end].view(param.data.shape) \ No newline at end of file + param.data = self.fp16_data[begin:end].view(param.data.shape) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index c3c9d007d..3c6e539ba 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -14,12 +14,12 @@ from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator -from colossalai.gemini.memory_tracer import MemStats from colossalai.interface import ModelWrapper, OptimizerWrapper -from colossalai.nn.parallel import GeminiDDP, zero_model_wrapper, zero_optim_wrapper from colossalai.tensor.colo_parameter import ColoParameter from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import _convert_to_coloparam +from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper +from colossalai.zero.gemini.colo_init_context import _convert_to_coloparam +from colossalai.zero.gemini.memory_tracer import MemStats from .plugin_base import Plugin diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py index 59d8e1058..ff8979d82 100644 --- a/colossalai/engine/_base_engine.py +++ b/colossalai/engine/_base_engine.py @@ -10,8 +10,8 @@ from torch.nn.modules.loss import _Loss from colossalai.engine.gradient_handler import BaseGradientHandler from colossalai.engine.schedule import BaseSchedule, InterleavedPipelineSchedule, NonPipelineSchedule, PipelineSchedule -from colossalai.gemini.ophooks import BaseOpHook, register_ophooks_recursively from colossalai.logging import get_dist_logger +from colossalai.zero.legacy.gemini import BaseOpHook, register_ophooks_recursively class Engine: diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 712ae8242..38175fe09 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -157,7 +157,7 @@ class PipelineSchedule(BaseSchedule): return self._move_to_device(mciro_batch_data) def pre_processing(self, engine): - from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 + from colossalai.zero.legacy import ShardedModelV2 # TODO: remove this after testing new zero with pipeline parallelism model = engine.model diff --git a/colossalai/gemini/__init__.py b/colossalai/gemini/__init__.py deleted file mode 100644 index 7a5a44ebb..000000000 --- a/colossalai/gemini/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration -from .gemini_mgr import GeminiManager -from .stateful_tensor_mgr import 
StatefulTensorMgr -from .tensor_placement_policy import TensorPlacementPolicyFactory - -__all__ = [ - 'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', - 'search_chunk_configuration' -] diff --git a/colossalai/initialize.py b/colossalai/initialize.py index f3719dcb4..5d3f3e553 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -29,13 +29,12 @@ from colossalai.engine.schedule import ( PipelineSchedule, get_tensor_shape, ) -from colossalai.gemini.ophooks import BaseOpHook from colossalai.logging import get_dist_logger from colossalai.nn.optimizer.colossalai_optimizer import ColossalaiOptimizer from colossalai.utils import get_current_device, is_using_ddp, is_using_pp, is_using_sequence, sync_model_param from colossalai.utils.moe import sync_moe_model_param -from colossalai.zero import convert_to_zero_v2 -from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2 +from colossalai.zero.legacy import ShardedOptimizerV2, convert_to_zero_v2 +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook def get_default_parser(): diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index 4fb9ad332..2e5d9e6e7 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -9,7 +9,7 @@ import torch.nn as nn from colossalai.context import ParallelMode, seed from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device -from colossalai.zero.init_ctx import no_shard_zero_decrator +from colossalai.zero.legacy.init_ctx import no_shard_zero_decrator class MoeExperts(nn.Module): diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index 0969eb818..b90d1f0bf 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -18,7 +18,7 @@ from colossalai.nn.layer.moe.experts import Experts, MoeExperts from colossalai.nn.layer.moe.routers import MoeRouter, Top1Router, Top2Router from colossalai.nn.layer.moe.utils import NormalNoiseGenerator, UniformNoiseGenerator from colossalai.utils import get_current_device -from colossalai.zero.init_ctx import no_shard_zero_context, no_shard_zero_decrator +from colossalai.zero.legacy.init_ctx import no_shard_zero_context, no_shard_zero_decrator @no_shard_zero_decrator(is_replicated=True) diff --git a/colossalai/nn/optimizer/gemini_optimizer.py b/colossalai/nn/optimizer/gemini_optimizer.py deleted file mode 100644 index 31d161612..000000000 --- a/colossalai/nn/optimizer/gemini_optimizer.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Any - -import torch - -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer - -__all__ = ['GeminiAdamOptimizer'] - - -class GeminiAdamOptimizer(ZeroOptimizer): - - def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: - optimizer = HybridAdam(model.parameters(), **defaults) - super().__init__(optimizer, model, **defaults) diff --git a/colossalai/nn/parallel/__init__.py b/colossalai/nn/parallel/__init__.py index 2afc8f18c..17e010f47 100644 --- a/colossalai/nn/parallel/__init__.py +++ b/colossalai/nn/parallel/__init__.py @@ -1,5 +1,5 @@ -from .data_parallel import ColoDDP, ZeroDDP -from .gemini_parallel import GeminiDDP -from .zero_wrapper import zero_model_wrapper, zero_optim_wrapper +from .data_parallel import ColoDDP -__all__ = ['ColoDDP', 'ZeroDDP', 'GeminiDDP', 'zero_model_wrapper', 'zero_optim_wrapper'] 
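The wrappers dropped from this export list are not deleted from the codebase; they are re-exported from the new colossalai.zero package defined later in this patch. A minimal migration sketch for downstream code, illustrative only and assuming the colossalai.zero layout introduced below:

    # old paths, removed by this patch
    # from colossalai.nn.parallel import ZeroDDP, GeminiDDP, zero_model_wrapper, zero_optim_wrapper
    # from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
    # from colossalai.utils.model.colo_init_context import ColoInitContext

    # new paths, re-exported by colossalai/zero/__init__.py
    from colossalai.zero import (
        ColoInitContext,
        GeminiAdamOptimizer,
        GeminiDDP,
        ZeroDDP,
        ZeroOptimizer,
        zero_model_wrapper,
        zero_optim_wrapper,
    )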
+__all__ = [ + 'ColoDDP', +] diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py index a9d001bd0..f839d6b28 100644 --- a/colossalai/nn/parallel/data_parallel.py +++ b/colossalai/nn/parallel/data_parallel.py @@ -1,31 +1,14 @@ -import itertools from collections import OrderedDict from functools import partial -from typing import Dict, Iterable, List, Optional, Set +from typing import Iterable, Optional, Set import torch import torch.distributed as dist -import torch.nn as nn -from colossalai.gemini.chunk import Chunk, ChunkManager, TensorState -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer import OrderedParamGenerator -from colossalai.logging import get_dist_logger -from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda from colossalai.tensor import ProcessGroup as ColoProcessGroup -from colossalai.tensor import ReplicaSpec -from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec -from colossalai.tensor.param_op_hook import ColoParamOpHookManager -from colossalai.utils import get_current_device, is_ddp_ignored -from colossalai.zero.utils.gemini_hook import GeminiZeROHook +from colossalai.utils import is_ddp_ignored from .reducer import Reducer -from .utils import get_static_torch_model - -try: - from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys -except ImportError: - _EXTRA_STATE_KEY_SUFFIX = '_extra_state' def free_storage(data: torch.Tensor) -> None: @@ -189,507 +172,3 @@ class ColoDDP(torch.nn.Module): def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): return self.module.load_state_dict(state_dict, strict) - - -class ZeroDDP(ColoDDP): - """ZeRO DDP for ColoTensor. - Warning: Nested ZeroDDP is not supported now. - It is designed to be used with ChunkManager and GeminiManager. - For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. - - Args: - module (torch.nn.Module): Module to apply ZeRO-DP. - gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. - For more details, see the API reference of ``GeminiManager``. - pin_memory (bool): Chunks on CPU Memory use pin-memory. - force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. - Defaults to False. - strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. - Defaults to False. Users can set it to True, when they clearly know that they only need DDP. - """ - - def __init__(self, - module: torch.nn.Module, - gemini_manager: GeminiManager, - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False) -> None: - super().__init__(module, process_group=ColoProcessGroup()) - self.gemini_manager = gemini_manager - self.chunk_manager: ChunkManager = gemini_manager.chunk_manager - self.force_outputs_fp32 = force_outputs_fp32 - self.param_op_hook = GeminiZeROHook(gemini_manager) - self.fp32_params: List[ColoTensor] = list() - self.fp16_params: List[ColoParameter] = list() - self.overflow_counter = 0 - self.grads_device: Dict[torch.Tensor, torch.device] = dict() - self.param2name: Dict[nn.Parameter, str] = dict() - self.name2param: Dict[str, nn.Parameter] = dict() - - self._cast_buffers() - self._logger = get_dist_logger() - - if self.gemini_manager._premade_memstats_: - # build chunk in param runtime visited order. 
- param_order = self.gemini_manager.memstats()._param_runtime_order - else: - # build chunk in param initialized order. - # Note: in this way, it can not get filter unused params during runtime. - param_order = OrderedParamGenerator() - for p in module.parameters(): - param_order.append(p) - - self._init_chunks(param_order=param_order, - strict_ddp_mode=strict_ddp_mode, - cpu_offload=self.gemini_manager.policy_name != 'cuda', - pin_memory=pin_memory) - - for name, param in module.named_parameters(): - self.param2name[param] = name - for m_name, m_var in module.named_modules(): - for p_name, p_var in m_var.named_parameters(recurse=False): - param_name = m_name + '.' + p_name if m_name else p_name - self.name2param[param_name] = p_var - - def _post_forward(self): - """This function is only triggered for inference. - """ - access_list = list(self.chunk_manager.accessed_chunks) - # we need to scatter all accessed chunks and move them to their original places - for chunk in access_list: - if chunk.keep_gathered: - self.chunk_manager.fake_release_chunk(chunk) - else: - assert chunk.can_release - self.chunk_manager.release_chunk(chunk) - first_param = next(iter(chunk.tensors_info)) - self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) - assert self.chunk_manager.accessed_mem == 0 - # reset all recorded attributes - self.gemini_manager.reset_attributes() - - def forward(self, *args, **kwargs): - # check whether we are in a inference mode - grad_flag = torch.is_grad_enabled() - if not grad_flag: - assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( - ), "You should run a completed iteration as your warmup iter" - - args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) - self.module.zero_grad(set_to_none=True) - self.gemini_manager.pre_iter(*args) - with ColoParamOpHookManager.use_hooks(self.param_op_hook): - outputs = self.module(*args, **kwargs) - # scatter chunks in the inference mode - if not grad_flag: - self._post_forward() - - if self.force_outputs_fp32: - return _cast_float(outputs, torch.float) - return outputs - - def _setup_grads_ptr(self): - for p in self.module.parameters(): - if is_ddp_ignored(p): - continue - p.grad = None - - def _pre_backward(self): - # set a visit label for all parameters - # the label is used to check whether the parameter is correctly reduced - for param in self.param2name: - if not is_ddp_ignored(param): - setattr(param, "_gemini_reduced", False) - - def _post_backward(self): - if self.chunk_manager.accessed_mem != 0: - error_params = ["Reduction failed at followed parameters:"] - for param in self.param2name: - if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): - error_params.append(self.param2name[param]) - error_str = "\n\t".join(error_params) - raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", - "The most possible reason is that the model is not compatible with ZeroDDP.\n", - f"{error_str}") - self._setup_grads_ptr() - self._logger.debug( - f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' - ) - self.gemini_manager.post_iter() - - def backward(self, loss: torch.Tensor): - self._pre_backward() - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - loss.backward() - 
self._post_backward() - - def backward_by_grad(self, tensor, grad): - with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): - torch.autograd.backward(tensor, grad) - self._post_backward() - - def grad_handle(self, p, grad): - empty_grad = torch.empty_like(grad) - free_storage(empty_grad) - with torch._C.DisableTorchFunction(): - chunk = self.chunk_manager.get_chunk(p) - if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: - raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " - "Some unsupported torch function is operated upon this parameter.") - self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) - chunk.copy_tensor_to_chunk_slice(p, grad) - reduced = self.chunk_manager.reduce_chunk(chunk) - if reduced: - if chunk.is_gathered: - chunk.cuda_global_chunk.div_(chunk.pg_size) - else: - chunk.cuda_shard.div_(chunk.pg_size) - # check overflow elements - self.overflow_counter += chunk.has_inf_or_nan - # record l2 norm for gradient clipping - if chunk.l2_norm_flag: - chunk.set_l2_norm() - self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) - return empty_grad - - def zero_grad(self, set_to_none: bool = False) -> None: - self.module.zero_grad(set_to_none=True) - - def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: - for tensor in chunk.get_tensors(): - self.grads_device[tensor] = device - - def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): - """Returns a dictionary containing a whole state of the module. - - Both parameters and persistent buffers (e.g. running averages) are included. - Keys are corresponding parameter and buffer names. - Parameters and buffers set to ``None`` are not included. - - Warning: The non strict state dict would ignore the parameters if the tensors of the parameters - are shared with other parameters which have been included in the dictionary. - When you need to load the state dict, you should set the argument `strict` to False. - - Returns: - dict: - a dictionary containing a whole state of the module - """ - if destination is None: - destination = OrderedDict() - destination._metadata = OrderedDict() - destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) - self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) - - for hook in self._state_dict_hooks.values(): - hook_result = hook(self, destination, prefix, local_metadata) - if hook_result is not None: - destination = hook_result - return destination - - def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: - """ - get param content from chunks. 
- - Args: - param_list (_type_): a list of torch.nn.Parameters - only_rank_0 (_type_): _description_ - - Returns: - Dict: a dict whose key is param name and value is param with correct payload - """ - # save parameters - param_to_save_data = dict() - chunk_list = self.chunk_manager.get_chunks(param_list) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - record_tensor = torch.empty([0]) - record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) - if record_flag: - record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() - - assert tensor not in param_to_save_data - param_to_save_data[tensor] = record_tensor - - del temp_chunk - return param_to_save_data - - def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): - r"""Saves module state to `destination` dictionary, containing a state - of the module, but not its descendants. This is called on every - submodule in :meth:`~torch.nn.Module.state_dict`. - - In rare cases, subclasses can achieve class-specific behavior by - overriding this method with custom logic. - - Args: - destination (dict): a dict where state will be stored - prefix (str): the prefix for parameters and buffers used in this - module - """ - assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." - - # get copies of fp32 parameters in CPU - param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) - # get the mapping between copies and fp16 parameters - p_mapping = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - name = self.param2name[p] - assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) - record_parameter = param_to_save_data[fp32_p] - p_mapping[p] = record_parameter - for name, param in self.name2param.items(): - if param is not None: - if is_ddp_ignored(param): - # deal with ddp ignored parameters - destination[prefix + name] = param if keep_vars else param.detach() - else: - destination[prefix + name] = p_mapping[param] - del p_mapping - del param_to_save_data - - # save all buffers - for name, buf in self.named_buffers(): - if buf is not None and name not in self._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - # save extra states - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "get_extra_state", - torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: - destination[extra_state_key] = self.get_extra_state() - - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): - r"""Copies parameters and buffers from :attr:`state_dict` into - this module and its descendants. If :attr:`strict` is ``True``, then - the keys of :attr:`state_dict` must exactly match the keys returned - by this module's :meth:`~torch.nn.Module.state_dict` function. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - strict (bool, optional): whether to strictly enforce that the keys - in :attr:`state_dict` match the keys returned by this module's - :meth:`~torch.nn.Module.state_dict` function. 
Default: ``True`` - - Returns: - ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: - * **missing_keys** is a list of str containing the missing keys - * **unexpected_keys** is a list of str containing the unexpected keys - - Note: - If a parameter or buffer is registered as ``None`` and its corresponding key - exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a - ``RuntimeError``. - """ - missing_keys: List[str] = [] - unexpected_keys: List[str] = [] - error_msgs: List[str] = [] - - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - # mypy isn't aware that "_metadata" exists in state_dict - state_dict._metadata = metadata # type: ignore[attr-defined] - - prefix = '' - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( - '"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) - - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( - self.__class__.__name__, "\n\t".join(error_msgs))) - return _IncompatibleKeys(missing_keys, unexpected_keys) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, - error_msgs): - r"""Copies parameters and buffers from :attr:`state_dict` into only - this module, but not its descendants. This is called on every submodule - in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this - module in input :attr:`state_dict` is provided as :attr:`local_metadata`. - For state dicts without metadata, :attr:`local_metadata` is empty. - Subclasses can achieve class-specific backward compatible loading using - the version number at `local_metadata.get("version", None)`. - - .. note:: - :attr:`state_dict` is not the same object as the input - :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So - it can be modified. - - Args: - state_dict (dict): a dict containing parameters and - persistent buffers. - prefix (str): the prefix for parameters and buffers used in this - module - local_metadata (dict): a dict containing the metadata for this module. 
- See - strict (bool): whether to strictly enforce that the keys in - :attr:`state_dict` with :attr:`prefix` match the names of - parameters and buffers in this module - missing_keys (list of str): if ``strict=True``, add missing keys to - this list - unexpected_keys (list of str): if ``strict=True``, add unexpected - keys to this list - error_msgs (list of str): error messages should be added to this - list, and will be reported together in - :meth:`~torch.nn.Module.load_state_dict` - """ - for hook in self._load_state_dict_pre_hooks.values(): - hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - - persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} - local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) - local_state = {k: v for k, v in local_name_params if v is not None} - - def load(param_name, dest_tensor, copy_func): - state_key = prefix + param_name - if state_key in state_dict: - input_param = state_dict[state_key] - # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ - if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: - input_param = input_param[0] - if input_param.shape != dest_tensor.shape: - # local shape should match the one in checkpoint - error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' - 'the shape in current model is {}.'.format(state_key, input_param.shape, - dest_tensor.shape)) - return - try: - with torch.no_grad(): - copy_func(input_param) - except Exception as ex: - error_msgs.append('While copying the parameter named "{}", ' - 'whose dimensions in the model are {} and ' - 'whose dimensions in the checkpoint are {}, ' - 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), - input_param.size(), ex.args)) - elif strict: - missing_keys.append(state_key) - - def load_fp32_parameter(chunk_slice, data): - chunk_slice.copy_(data.flatten()) - - for name, param in self.named_parameters(): - if is_ddp_ignored(param): - # deal with ddp ignored parameters - load(name, param, param.copy_) - - fp32_to_name = dict() - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - if p is not None: - name = self.param2name[p] - fp32_to_name[fp32_p] = name - - chunk_list = self.chunk_manager.get_chunks(self.fp32_params) - for chunk in chunk_list: - temp_chunk = get_temp_total_chunk_on_cuda(chunk) - - for tensor, tensor_info in chunk.tensors_info.items(): - parameter_name = fp32_to_name[tensor] - parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] - load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) - - if chunk.is_gathered: - chunk.cuda_global_chunk.copy_(temp_chunk) - elif chunk.cuda_shard is not None: - chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - else: - chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) - - del temp_chunk - - for chunk_32 in chunk_list: - chunk_16 = chunk_32.paired_chunk - assert chunk_16 is not None - chunk_16.optim_update() - - for name, buf in persistent_buffers.items(): - if buf is not None: - load(name, buf, buf.copy_) - - extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX - if getattr(self.__class__, "set_extra_state", - torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: - if extra_state_key in state_dict: - self.set_extra_state(state_dict[extra_state_key]) - elif strict: - missing_keys.append(extra_state_key) - elif strict and 
(extra_state_key in state_dict): - unexpected_keys.append(extra_state_key) - - if strict: - for key in state_dict.keys(): - if key.startswith(prefix) and key != extra_state_key: - input_name = key[len(prefix):] - if input_name not in local_state: - unexpected_keys.append(key) - - def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): - ddp_pg = ColoProcessGroup() - for p in param_order.generate(): - assert isinstance(p, ColoParameter) - - # gather sharded parameters in the strict ddp mode - if strict_ddp_mode: - if not p.is_replicate(): - p.set_dist_spec(ReplicaSpec()) - p.set_process_group(pg=ddp_pg) - - # ignore the parameters with no gradient - if not p.requires_grad: - self.set_params_to_ignore([p]) - - # move ignored parameters to CUDA - if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=torch.float16) - continue - - # create a fp32 parameter - fp32_data = p.data.float() - fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) - # create a fp16 parameter - p.data = p.data.half() - - # register the fp16 parameter and fp32 parameter in the chunk manager - dp_world_size = p.process_group.dp_world_size() - self.chunk_manager.register_tensor(tensor=p, - group_type='fp16_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - self.chunk_manager.register_tensor(tensor=fp32_p, - group_type='fp32_param', - config_key=dp_world_size, - cpu_offload=cpu_offload, - pin_memory=pin_memory) - - self.fp16_params.append(p) - self.fp32_params.append(fp32_p) - self.grads_device[p] = self.gemini_manager.default_device - - self.chunk_manager.close_all_groups() - - for p, fp32_p in zip(self.fp16_params, self.fp32_params): - chunk_16 = self.chunk_manager.get_chunk(p) - chunk_32 = self.chunk_manager.get_chunk(fp32_p) - chunk_32.init_pair(chunk_16) - - # keep gathered chunks are in CUDA - if chunk_16.keep_gathered: - self.grads_device[p] = get_current_device() - - def _cast_buffers(self): - for buffer in self.module.buffers(): - buffer.data = buffer.cuda() - if torch.is_floating_point(buffer): - buffer.data = buffer.half() diff --git a/colossalai/nn/parallel/gemini_parallel.py b/colossalai/nn/parallel/gemini_parallel.py deleted file mode 100644 index 2c6e15d91..000000000 --- a/colossalai/nn/parallel/gemini_parallel.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Optional - -import torch - -from colossalai.gemini.chunk import init_chunk_manager -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer import MemStats - -from .data_parallel import ZeroDDP - - -class GeminiDDP(ZeroDDP): - - def __init__(self, - module: torch.nn.Module, - device: torch.device, - placement_policy: str = "cpu", - pin_memory: bool = False, - force_outputs_fp32: bool = False, - strict_ddp_mode: bool = False, - search_range_mb: int = 32, - hidden_dim: Optional[int] = None, - min_chunk_size_mb: float = 32, - memstats: Optional[MemStats] = None) -> None: - """ - A torch.Module warpper using ZeRO-DP and Genimi. - ZeRO is for parallel. Gemini is for memory management. - WARNING: The class will modify the module inline! - - Example: - model is initialized under the context of ColoInitContext - >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") - >>> logits = model(x) - >>> loss = criterion(logits, labels) - >>> model.backward(loss) - - Args: - module (torch.nn.Module): the model to be wrapped. - device (torch.device): device to place the model. 
- placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". - pin_memory (bool, optional): use pin memory on CPU. Defaults to False. - force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. - search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. - hidden_dim (int, optional): the hidden dimension of DNN. - Users can provide this argument to speed up searching. - If users do not know this argument before training, it is ok. We will use a default value 1024. - min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. - If the aggregate size of parameters is still samller than the minimum chunk size, - all parameters will be compacted into one small chunk. - memstats (MemStats, optional) the memory statistics collector by a runtime memory tracer. - """ - # some ugly hotfix for the compatibility with Lightning - if search_range_mb is None: - search_range_mb = 32 - - chunk_manager = init_chunk_manager(model=module, - init_device=device, - hidden_dim=hidden_dim, - search_range_mb=search_range_mb, - min_chunk_size_mb=min_chunk_size_mb, - strict_ddp_flag=strict_ddp_mode) - gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/zero/__init__.py b/colossalai/zero/__init__.py index 098ccbb45..3465079e4 100644 --- a/colossalai/zero/__init__.py +++ b/colossalai/zero/__init__.py @@ -1,41 +1,16 @@ -from typing import Tuple +from .gemini import ( + ColoInitContext, + GeminiAdamOptimizer, + GeminiDDP, + ZeroDDP, + ZeroOptimizer, + get_static_torch_model, + post_process_colo_init_ctx, +) +from .low_level import LowLevelZeroOptimizer +from .wrapper import zero_model_wrapper, zero_optim_wrapper -import torch -import torch.nn as nn - -from colossalai.logging import get_dist_logger -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_optim import LowLevelZeroOptimizer, ShardedOptimizerV2 - -from ..nn.optimizer.zero_optimizer import ZeroOptimizer - - -def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, - optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: - """ - A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading - - :param model: Your model object - :type model: :class:`torch.nn.Module` - :param optimizer_config: Your optimizer object - :type optimizer_config: :class:`dict` - - :return: (model, optimizer) - :rtype: Tuple - """ - - logger = get_dist_logger('convert_to_zero_v2') - - logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) - if optimizer_config is None: - optimizer_config = dict() - logger.info(f'model_config is {model_config}', ranks=[0]) - if model_config is None: - model_config = dict() - - zero_model = ShardedModelV2(model, **model_config) - zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) - return zero_model, zero_optimizer - - -__all__ = ['convert_to_zero_v2', 'LowLevelZeroOptimizer', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroOptimizer'] +__all__ = [ + 'ZeroDDP', 'GeminiDDP', 'ZeroOptimizer', 'GeminiAdamOptimizer', 'zero_model_wrapper', 'zero_optim_wrapper', + 'LowLevelZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx', 'get_static_torch_model' +] diff --git a/colossalai/zero/gemini/__init__.py b/colossalai/zero/gemini/__init__.py new file mode 100644 index 
000000000..60f85ca2f --- /dev/null +++ b/colossalai/zero/gemini/__init__.py @@ -0,0 +1,11 @@ +from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration +from .colo_init_context import ColoInitContext, post_process_colo_init_ctx +from .gemini_ddp import GeminiDDP, ZeroDDP +from .gemini_mgr import GeminiManager +from .gemini_optimizer import GeminiAdamOptimizer, ZeroOptimizer +from .utils import get_static_torch_model + +__all__ = [ + 'GeminiManager', 'TensorInfo', 'TensorState', 'ChunkManager', 'search_chunk_configuration', 'ZeroDDP', 'GeminiDDP', + 'get_static_torch_model', 'GeminiAdamOptimizer', 'ZeroOptimizer', 'ColoInitContext', 'post_process_colo_init_ctx' +] diff --git a/colossalai/gemini/chunk/__init__.py b/colossalai/zero/gemini/chunk/__init__.py similarity index 100% rename from colossalai/gemini/chunk/__init__.py rename to colossalai/zero/gemini/chunk/__init__.py diff --git a/colossalai/gemini/chunk/chunk.py b/colossalai/zero/gemini/chunk/chunk.py similarity index 100% rename from colossalai/gemini/chunk/chunk.py rename to colossalai/zero/gemini/chunk/chunk.py diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/zero/gemini/chunk/manager.py similarity index 99% rename from colossalai/gemini/chunk/manager.py rename to colossalai/zero/gemini/chunk/manager.py index 2fa65c970..d85df0b00 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/zero/gemini/chunk/manager.py @@ -3,10 +3,11 @@ from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple import torch -from colossalai.gemini.chunk import Chunk, ChunkFullError, TensorState from colossalai.tensor import ColoTensor from colossalai.utils import get_current_device +from .chunk import Chunk, ChunkFullError, TensorState + class ChunkManager: """ diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py similarity index 98% rename from colossalai/gemini/chunk/search_utils.py rename to colossalai/zero/gemini/chunk/search_utils.py index fe9650721..a69b782ea 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -5,9 +5,9 @@ import numpy as np import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.memory_tracer import MemStats, OrderedParamGenerator from colossalai.tensor import ColoParameter from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: diff --git a/colossalai/gemini/chunk/utils.py b/colossalai/zero/gemini/chunk/utils.py similarity index 91% rename from colossalai/gemini/chunk/utils.py rename to colossalai/zero/gemini/chunk/utils.py index 83512b8e0..283f74203 100644 --- a/colossalai/gemini/chunk/utils.py +++ b/colossalai/zero/gemini/chunk/utils.py @@ -5,10 +5,11 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.chunk import ChunkManager -from colossalai.gemini.chunk.search_utils import search_chunk_configuration from colossalai.utils import is_ddp_ignored +from .manager import ChunkManager +from .search_utils import search_chunk_configuration + def safe_div(a, b): if a == 0: diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/zero/gemini/colo_init_context.py similarity index 97% rename from colossalai/utils/model/colo_init_context.py rename to colossalai/zero/gemini/colo_init_context.py index 87ae413a2..5937ee9ef 100644 --- 
a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/zero/gemini/colo_init_context.py @@ -3,10 +3,8 @@ from typing import Any, Dict, Iterator, Optional, Tuple, Union import torch from torch import nn -from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup - -from .utils import InsertPostInitMethodToModuleSubClasses +from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses # find named_params includes replica @@ -89,6 +87,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses): self._default_dist_spec = default_dist_spec def _register_colo_modules(self): + from colossalai.nn.parallel.layers import ColoEmbedding, ColoLinear, register_colo_module register_colo_module(torch.nn.Linear, ColoLinear()) register_colo_module(torch.nn.Embedding, ColoEmbedding()) diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py new file mode 100644 index 000000000..50f1b1ef1 --- /dev/null +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -0,0 +1,590 @@ +import itertools +from collections import OrderedDict +from functools import partial +from typing import Dict, List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn + +from colossalai.logging import get_dist_logger +from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage +from colossalai.tensor import ProcessGroup as ColoProcessGroup +from colossalai.tensor import ReplicaSpec +from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device, is_ddp_ignored + +from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager +from .gemini_hook import GeminiZeROHook +from .gemini_mgr import GeminiManager +from .memory_tracer import MemStats, OrderedParamGenerator +from .utils import get_temp_total_chunk_on_cuda + +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + +__all__ = [ + 'ZeroDDP', + 'GeminiDDP', +] + + +class ZeroDDP(ColoDDP): + """ZeRO DDP for ColoTensor. + Warning: Nested ZeroDDP is not supported now. + It is designed to be used with ChunkManager and GeminiManager. + For more details, see the API reference of ``ChunkManager`` and ``GeminiManager``. + + Args: + module (torch.nn.Module): Module to apply ZeRO-DP. + gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space. + For more details, see the API reference of ``GeminiManager``. + pin_memory (bool): Chunks on CPU Memory use pin-memory. + force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. + Defaults to False. + strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. + Defaults to False. Users can set it to True, when they clearly know that they only need DDP. 
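As a point of reference for the docstring above: wiring ZeroDDP up by hand follows the chunk-manager / Gemini-manager pattern sketched below. This is assembled from the signatures visible in this patch (init_chunk_manager, GeminiManager, ZeroDDP) and is not part of the diff; the GeminiDDP subclass defined at the end of this file performs the same steps automatically. `model` is assumed to have been built under ColoInitContext so that its parameters are ColoParameters.

    import torch
    from colossalai.zero.gemini import GeminiManager, ZeroDDP
    from colossalai.zero.gemini.chunk import init_chunk_manager

    # search for a chunk configuration for this model, then hand the chunk manager to Gemini
    chunk_manager = init_chunk_manager(model=model, init_device=torch.device('cuda'), search_range_mb=32)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    zero_model = ZeroDDP(model, gemini_manager, pin_memory=True)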
+ """ + + def __init__(self, + module: torch.nn.Module, + gemini_manager: GeminiManager, + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False) -> None: + super().__init__(module, process_group=ColoProcessGroup()) + self.gemini_manager = gemini_manager + self.chunk_manager: ChunkManager = gemini_manager.chunk_manager + self.force_outputs_fp32 = force_outputs_fp32 + self.param_op_hook = GeminiZeROHook(gemini_manager) + self.fp32_params: List[ColoTensor] = list() + self.fp16_params: List[ColoParameter] = list() + self.overflow_counter = 0 + self.grads_device: Dict[torch.Tensor, torch.device] = dict() + self.param2name: Dict[nn.Parameter, str] = dict() + self.name2param: Dict[str, nn.Parameter] = dict() + + self._cast_buffers() + self._logger = get_dist_logger() + + if self.gemini_manager._premade_memstats_: + # build chunk in param runtime visited order. + param_order = self.gemini_manager.memstats()._param_runtime_order + else: + # build chunk in param initialized order. + # Note: in this way, it can not get filter unused params during runtime. + param_order = OrderedParamGenerator() + for p in module.parameters(): + param_order.append(p) + + self._init_chunks(param_order=param_order, + strict_ddp_mode=strict_ddp_mode, + cpu_offload=self.gemini_manager.policy_name != 'cuda', + pin_memory=pin_memory) + + for name, param in module.named_parameters(): + self.param2name[param] = name + for m_name, m_var in module.named_modules(): + for p_name, p_var in m_var.named_parameters(recurse=False): + param_name = m_name + '.' + p_name if m_name else p_name + self.name2param[param_name] = p_var + + def _post_forward(self): + """This function is only triggered for inference. + """ + access_list = list(self.chunk_manager.accessed_chunks) + # we need to scatter all accessed chunks and move them to their original places + for chunk in access_list: + if chunk.keep_gathered: + self.chunk_manager.fake_release_chunk(chunk) + else: + assert chunk.can_release + self.chunk_manager.release_chunk(chunk) + first_param = next(iter(chunk.tensors_info)) + self.chunk_manager.move_chunk(chunk, self.grads_device[first_param]) + assert self.chunk_manager.accessed_mem == 0 + # reset all recorded attributes + self.gemini_manager.reset_attributes() + + def forward(self, *args, **kwargs): + # check whether we are in a inference mode + grad_flag = torch.is_grad_enabled() + if not grad_flag: + assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( + ), "You should run a completed iteration as your warmup iter" + + args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + self.module.zero_grad(set_to_none=True) + self.gemini_manager.pre_iter(*args) + with ColoParamOpHookManager.use_hooks(self.param_op_hook): + outputs = self.module(*args, **kwargs) + # scatter chunks in the inference mode + if not grad_flag: + self._post_forward() + + if self.force_outputs_fp32: + return _cast_float(outputs, torch.float) + return outputs + + def _setup_grads_ptr(self): + for p in self.module.parameters(): + if is_ddp_ignored(p): + continue + p.grad = None + + def _pre_backward(self): + # set a visit label for all parameters + # the label is used to check whether the parameter is correctly reduced + for param in self.param2name: + if not is_ddp_ignored(param): + setattr(param, "_gemini_reduced", False) + + def _post_backward(self): + if self.chunk_manager.accessed_mem != 0: + error_params = ["Reduction failed at followed parameters:"] + for param in 
self.param2name: + if not is_ddp_ignored(param) and not getattr(param, "_gemini_reduced"): + error_params.append(self.param2name[param]) + error_str = "\n\t".join(error_params) + raise RuntimeError("ZERO DDP error: the synchronization of gradients doesn't exit properly.", + "The most possible reason is that the model is not compatible with ZeroDDP.\n", + f"{error_str}") + self._setup_grads_ptr() + self._logger.debug( + f'comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}' + ) + self.gemini_manager.post_iter() + + def backward(self, loss: torch.Tensor): + self._pre_backward() + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + loss.backward() + self._post_backward() + + def backward_by_grad(self, tensor, grad): + with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook): + torch.autograd.backward(tensor, grad) + self._post_backward() + + def grad_handle(self, p, grad): + empty_grad = torch.empty_like(grad) + free_storage(empty_grad) + with torch._C.DisableTorchFunction(): + chunk = self.chunk_manager.get_chunk(p) + if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD: + raise RuntimeError(f"Parameter `{self.param2name[p]}` failed at the gradient reduction. " + "Some unsupported torch function is operated upon this parameter.") + self.chunk_manager.trans_tensor_state(p, TensorState.READY_FOR_REDUCE) + chunk.copy_tensor_to_chunk_slice(p, grad) + reduced = self.chunk_manager.reduce_chunk(chunk) + if reduced: + if chunk.is_gathered: + chunk.cuda_global_chunk.div_(chunk.pg_size) + else: + chunk.cuda_shard.div_(chunk.pg_size) + # check overflow elements + self.overflow_counter += chunk.has_inf_or_nan + # record l2 norm for gradient clipping + if chunk.l2_norm_flag: + chunk.set_l2_norm() + self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True) + return empty_grad + + def zero_grad(self, set_to_none: bool = False) -> None: + self.module.zero_grad(set_to_none=True) + + def set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None: + for tensor in chunk.get_tensors(): + self.grads_device[tensor] = device + + def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are included. + Keys are corresponding parameter and buffer names. + Parameters and buffers set to ``None`` are not included. + + Warning: The non strict state dict would ignore the parameters if the tensors of the parameters + are shared with other parameters which have been included in the dictionary. + When you need to load the state dict, you should set the argument `strict` to False. 
+ + Returns: + dict: + a dictionary containing a whole state of the module + """ + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) + self._save_to_state_dict(destination, prefix, keep_vars, only_rank_0) + + for hook in self._state_dict_hooks.values(): + hook_result = hook(self, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict: + """ + get param content from chunks. + + Args: + param_list (_type_): a list of torch.nn.Parameters + only_rank_0 (_type_): _description_ + + Returns: + Dict: a dict whose key is param name and value is param with correct payload + """ + # save parameters + param_to_save_data = dict() + chunk_list = self.chunk_manager.get_chunks(param_list) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + record_tensor = torch.empty([0]) + record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0) + if record_flag: + record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu() + + assert tensor not in param_to_save_data + param_to_save_data[tensor] = record_tensor + + del temp_chunk + return param_to_save_data + + def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now." + + # get copies of fp32 parameters in CPU + param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0) + # get the mapping between copies and fp16 parameters + p_mapping = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + name = self.param2name[p] + assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) + record_parameter = param_to_save_data[fp32_p] + p_mapping[p] = record_parameter + for name, param in self.name2param.items(): + if param is not None: + if is_ddp_ignored(param): + # deal with ddp ignored parameters + destination[prefix + name] = param if keep_vars else param.detach() + else: + destination[prefix + name] = p_mapping[param] + del p_mapping + del param_to_save_data + + # save all buffers + for name, buf in self.named_buffers(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + # save extra states + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", + torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + r"""Copies parameters and buffers from :attr:`state_dict` into + this module and its descendants. 
If :attr:`strict` is ``True``, then + the keys of :attr:`state_dict` must exactly match the keys returned + by this module's :meth:`~torch.nn.Module.state_dict` function. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + strict (bool, optional): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``True`` + + Returns: + ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields: + * **missing_keys** is a list of str containing the missing keys + * **unexpected_keys** is a list of str containing the unexpected keys + + Note: + If a parameter or buffer is registered as ``None`` and its corresponding key + exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a + ``RuntimeError``. + """ + missing_keys: List[str] = [] + unexpected_keys: List[str] = [] + error_msgs: List[str] = [] + + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + # mypy isn't aware that "_metadata" exists in state_dict + state_dict._metadata = metadata # type: ignore[attr-defined] + + prefix = '' + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + self._load_from_state_dict(state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join( + '"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + return _IncompatibleKeys(missing_keys, unexpected_keys) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. 
+ See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self.named_buffers() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + def load(param_name, dest_tensor, copy_func): + state_key = prefix + param_name + if state_key in state_dict: + input_param = state_dict[state_key] + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + if input_param.shape != dest_tensor.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(state_key, input_param.shape, + dest_tensor.shape)) + return + try: + with torch.no_grad(): + copy_func(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(state_key, dest_tensor.size(), + input_param.size(), ex.args)) + elif strict: + missing_keys.append(state_key) + + def load_fp32_parameter(chunk_slice, data): + chunk_slice.copy_(data.flatten()) + + for name, param in self.named_parameters(): + if is_ddp_ignored(param): + # deal with ddp ignored parameters + load(name, param, param.copy_) + + fp32_to_name = dict() + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + if p is not None: + name = self.param2name[p] + fp32_to_name[fp32_p] = name + + chunk_list = self.chunk_manager.get_chunks(self.fp32_params) + for chunk in chunk_list: + temp_chunk = get_temp_total_chunk_on_cuda(chunk) + + for tensor, tensor_info in chunk.tensors_info.items(): + parameter_name = fp32_to_name[tensor] + parameter_slice = temp_chunk[tensor_info.offset:tensor_info.end] + load(parameter_name, tensor, partial(load_fp32_parameter, parameter_slice)) + + if chunk.is_gathered: + chunk.cuda_global_chunk.copy_(temp_chunk) + elif chunk.cuda_shard is not None: + chunk.cuda_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + else: + chunk.cpu_shard.copy_(temp_chunk[chunk.shard_begin:chunk.shard_end]) + + del temp_chunk + + for chunk_32 in chunk_list: + chunk_16 = chunk_32.paired_chunk + assert chunk_16 is not None + chunk_16.optim_update() + + for name, buf in persistent_buffers.items(): + if buf is not None: + load(name, buf, buf.copy_) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", + torch.nn.Module.set_extra_state) is not torch.nn.Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and 
(extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + if input_name not in local_state: + unexpected_keys.append(key) + + def _init_chunks(self, param_order, strict_ddp_mode: bool, cpu_offload: bool, pin_memory: bool): + ddp_pg = ColoProcessGroup() + for p in param_order.generate(): + assert isinstance(p, ColoParameter) + + # gather sharded parameters in the strict ddp mode + if strict_ddp_mode: + if not p.is_replicate(): + p.set_dist_spec(ReplicaSpec()) + p.set_process_group(pg=ddp_pg) + + # ignore the parameters with no gradient + if not p.requires_grad: + self.set_params_to_ignore([p]) + + # move ignored parameters to CUDA + if is_ddp_ignored(p): + p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + continue + + # create a fp32 parameter + fp32_data = p.data.float() + fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) + # create a fp16 parameter + p.data = p.data.half() + + # register the fp16 parameter and fp32 parameter in the chunk manager + dp_world_size = p.process_group.dp_world_size() + self.chunk_manager.register_tensor(tensor=p, + group_type='fp16_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + self.chunk_manager.register_tensor(tensor=fp32_p, + group_type='fp32_param', + config_key=dp_world_size, + cpu_offload=cpu_offload, + pin_memory=pin_memory) + + self.fp16_params.append(p) + self.fp32_params.append(fp32_p) + self.grads_device[p] = self.gemini_manager.default_device + + self.chunk_manager.close_all_groups() + + for p, fp32_p in zip(self.fp16_params, self.fp32_params): + chunk_16 = self.chunk_manager.get_chunk(p) + chunk_32 = self.chunk_manager.get_chunk(fp32_p) + chunk_32.init_pair(chunk_16) + + # keep gathered chunks in CUDA + if chunk_16.keep_gathered: + self.grads_device[p] = get_current_device() + + def _cast_buffers(self): + for buffer in self.module.buffers(): + buffer.data = buffer.cuda() + if torch.is_floating_point(buffer): + buffer.data = buffer.half() + + +class GeminiDDP(ZeroDDP): + + def __init__(self, + module: torch.nn.Module, + device: torch.device, + placement_policy: str = "cpu", + pin_memory: bool = False, + force_outputs_fp32: bool = False, + strict_ddp_mode: bool = False, + search_range_mb: int = 32, + hidden_dim: Optional[int] = None, + min_chunk_size_mb: float = 32, + memstats: Optional[MemStats] = None) -> None: + """ + A torch.nn.Module wrapper using ZeRO-DP and Gemini. + ZeRO is for parallelism. Gemini is for memory management. + WARNING: The class will modify the module in-place! + + Example: + model is initialized under the context of ColoInitContext + >>> model = GeminiDDP(model, torch.cuda.current_device(), "cuda") + >>> logits = model(x) + >>> loss = criterion(logits, labels) + >>> model.backward(loss) + + Args: + module (torch.nn.Module): the model to be wrapped. + device (torch.device): device to place the model. + placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + pin_memory (bool, optional): use pin memory on CPU. Defaults to False. + force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False. + search_range_mb (int, optional): chunk size searching range in MegaByte. Defaults to 32. + hidden_dim (int, optional): the hidden dimension of DNN. + Users can provide this argument to speed up searching. + If users do not know this argument before training, it is ok.
We will use a default value 1024. + min_chunk_size_mb (float, optional): the minimum chunk size in MegaByte. + If the aggregate size of parameters is still smaller than the minimum chunk size, + all parameters will be compacted into one small chunk. + memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer. + """ + # some ugly hotfix for the compatibility with Lightning + if search_range_mb is None: + search_range_mb = 32 + + chunk_manager = init_chunk_manager(model=module, + init_device=device, + hidden_dim=hidden_dim, + search_range_mb=search_range_mb, + min_chunk_size_mb=min_chunk_size_mb, + strict_ddp_flag=strict_ddp_mode) + gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) + super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode) diff --git a/colossalai/zero/utils/gemini_hook.py b/colossalai/zero/gemini/gemini_hook.py similarity index 95% rename from colossalai/zero/utils/gemini_hook.py rename to colossalai/zero/gemini/gemini_hook.py index bddc307a0..dbc292485 100644 --- a/colossalai/zero/utils/gemini_hook.py +++ b/colossalai/zero/gemini/gemini_hook.py @@ -5,10 +5,10 @@ from typing import List import torch -from colossalai.gemini import TensorState -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.tensor.param_op_hook import ColoParamOpHook from colossalai.utils import is_ddp_ignored +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.gemini_mgr import GeminiManager class TrainingPhase(Enum): diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/zero/gemini/gemini_mgr.py similarity index 97% rename from colossalai/gemini/gemini_mgr.py rename to colossalai/zero/gemini/gemini_mgr.py index 72a5e4a7f..c38e6eff8 100644 --- a/colossalai/gemini/gemini_mgr.py +++ b/colossalai/zero/gemini/gemini_mgr.py @@ -4,10 +4,8 @@ from typing import List, Optional, Tuple import torch -from colossalai.gemini.chunk import Chunk, ChunkManager -from colossalai.gemini.memory_tracer import MemStats - -from .memory_tracer import ChunkMemStatsCollector +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector, MemStats from .placement_policy import PlacementPolicyFactory diff --git a/colossalai/nn/optimizer/zero_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py similarity index 97% rename from colossalai/nn/optimizer/zero_optimizer.py rename to colossalai/zero/gemini/gemini_optimizer.py index 422ebb7a3..8e0237ddc 100644 --- a/colossalai/nn/optimizer/zero_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -10,12 +10,15 @@ from torch.nn import Parameter from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler -from colossalai.gemini.chunk import Chunk, ChunkManager from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam -from colossalai.nn.parallel.data_parallel import ZeroDDP from colossalai.utils import disposable, get_current_device, is_ddp_ignored +from .chunk import Chunk, ChunkManager +from .gemini_ddp import ZeroDDP + +__all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] + _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} @@ -316,3 +319,10 @@ class ZeroOptimizer(ColossalaiOptimizer): fake_params_list.append(fake_param) group['params'] = fake_params_list + + +class GeminiAdamOptimizer(ZeroOptimizer): + + def __init__(self, model: torch.nn.Module, **defaults: Any) -> None: + optimizer =
HybridAdam(model.parameters(), **defaults) + super().__init__(optimizer, model, **defaults) diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/zero/gemini/memory_tracer/__init__.py similarity index 100% rename from colossalai/gemini/memory_tracer/__init__.py rename to colossalai/zero/gemini/memory_tracer/__init__.py diff --git a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py similarity index 91% rename from colossalai/gemini/memory_tracer/chunk_memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py index 1a5b6bf52..f5eb05b4f 100644 --- a/colossalai/gemini/memory_tracer/chunk_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/chunk_memstats_collector.py @@ -1,10 +1,10 @@ from typing import Optional -from colossalai.gemini.chunk import ChunkManager -from colossalai.gemini.memory_tracer import MemStats from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.chunk import ChunkManager +from .memory_stats import MemStats from .memstats_collector import MemStatsCollector diff --git a/colossalai/gemini/memory_tracer/memory_monitor.py b/colossalai/zero/gemini/memory_tracer/memory_monitor.py similarity index 100% rename from colossalai/gemini/memory_tracer/memory_monitor.py rename to colossalai/zero/gemini/memory_tracer/memory_monitor.py diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/zero/gemini/memory_tracer/memory_stats.py similarity index 98% rename from colossalai/gemini/memory_tracer/memory_stats.py rename to colossalai/zero/gemini/memory_tracer/memory_stats.py index 84fa00fb9..9a45034ee 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/zero/gemini/memory_tracer/memory_stats.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional import torch -from colossalai.gemini.memory_tracer import OrderedParamGenerator +from .param_runtime_order import OrderedParamGenerator class MemStats(object): diff --git a/colossalai/gemini/memory_tracer/memstats_collector.py b/colossalai/zero/gemini/memory_tracer/memstats_collector.py similarity index 92% rename from colossalai/gemini/memory_tracer/memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/memstats_collector.py index d939da6eb..0694be485 100644 --- a/colossalai/gemini/memory_tracer/memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/memstats_collector.py @@ -1,12 +1,7 @@ import time -from typing import List, Optional - -import torch - -from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.utils.memory import colo_device_memory_used +from typing import Optional +from .memory_monitor import SyncCudaMemoryMonitor from .memory_stats import MemStats @@ -49,7 +44,7 @@ class MemStatsCollector: assert self._step_total > 0, 'Cannot get mem stats info before collection phase.' 
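For orientation on the hunk above: `GeminiAdamOptimizer` now lives in `colossalai.zero.gemini.gemini_optimizer` as a thin `ZeroOptimizer` subclass that builds a `HybridAdam` over the wrapped model's parameters, and it is re-exported from `colossalai.zero` (as the example scripts later in this patch show). A minimal usage sketch; the toy model, the hyperparameters, and the `torchrun` launch are illustrative assumptions, not part of the change:

import torch

import colossalai
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP

# assumes a `torchrun` launch so the rank/world-size environment variables are set
colossalai.launch_from_torch(config={})

# build the model under ColoInitContext so its parameters become ColoParameters
with ColoInitContext(device=get_current_device()):
    model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024))

# wrap with Gemini ZeRO-DP, then pair it with the chunk-aware Adam optimizer
model = GeminiDDP(model, device=get_current_device(), placement_policy='cpu', pin_memory=True)
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**14)    # HybridAdam under the hood

data = torch.randn(8, 1024, device=get_current_device())
optimizer.zero_grad()
loss = model(data).sum()
model.backward(loss)    # backward goes through the ZeroDDP wrapper, as in the GeminiDDP docstring above
optimizer.step()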
assert len(self._memstats.non_model_data_list(device_type)) > self._step_idx, \ f"{len(self._memstats.non_model_data_list(device_type))} should be > than step idx {self._step_idx}, "\ - f"step total {self._step_total}" + f"step total {self._step_total}" next_non_model_data = self._memstats.non_model_data_list(device_type)[self._step_idx] self._step_idx = (self._step_idx + 1) % self._step_total return next_non_model_data @@ -75,6 +70,8 @@ class MemStatsCollector: Sampling model data statistics. """ if self._start_flag and not self.use_outside_memstats: + from colossalai.zero.legacy.gemini import StatefulTensor + # The following code work for ZeroInitContext, which is deprecated in v0.1.12 cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda'] self._memstats.record_max_cuda_model_data(cuda_mem) diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/zero/gemini/memory_tracer/param_runtime_order.py similarity index 100% rename from colossalai/gemini/memory_tracer/param_runtime_order.py rename to colossalai/zero/gemini/memory_tracer/param_runtime_order.py diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py similarity index 95% rename from colossalai/gemini/memory_tracer/runtime_mem_tracer.py rename to colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py index a643751da..0c9eac8b6 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/zero/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,9 +1,14 @@ import torch.nn -from colossalai.gemini.memory_tracer import MemStats -from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.zero.legacy.gemini.ophooks.runtime_mem_tracer_hook import ( + GradMemStats, + GradMemTracerHook, + ParamMemTracerHook, +) + +from .memory_stats import MemStats __all__ = ['RuntimeMemTracer'] diff --git a/colossalai/gemini/memory_tracer/static_memstats_collector.py b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py similarity index 98% rename from colossalai/gemini/memory_tracer/static_memstats_collector.py rename to colossalai/zero/gemini/memory_tracer/static_memstats_collector.py index 3209881e1..b8f9a095f 100644 --- a/colossalai/gemini/memory_tracer/static_memstats_collector.py +++ b/colossalai/zero/gemini/memory_tracer/static_memstats_collector.py @@ -6,7 +6,7 @@ from torch.fx import symbolic_trace from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta -from colossalai.gemini.chunk import ChunkManager +from colossalai.zero.gemini.chunk import ChunkManager if is_compatible_with_meta(): from colossalai.fx.profiler import MetaTensor diff --git a/colossalai/gemini/memory_tracer/utils.py b/colossalai/zero/gemini/memory_tracer/utils.py similarity index 100% rename from colossalai/gemini/memory_tracer/utils.py rename to colossalai/zero/gemini/memory_tracer/utils.py diff --git a/colossalai/gemini/placement_policy.py b/colossalai/zero/gemini/placement_policy.py similarity index 98% rename from colossalai/gemini/placement_policy.py rename to colossalai/zero/gemini/placement_policy.py index fed1cc298..84a868872 100644 --- a/colossalai/gemini/placement_policy.py +++ b/colossalai/zero/gemini/placement_policy.py @@ -5,11 +5,12 @@ from 
typing import Dict, List, Optional, Tuple, Type import torch -from colossalai.gemini.chunk import Chunk, ChunkManager -from colossalai.gemini.memory_tracer import ChunkMemStatsCollector from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from .chunk import Chunk, ChunkManager +from .memory_tracer import ChunkMemStatsCollector + class PlacementPolicy(ABC): need_mem_stats: bool = False diff --git a/colossalai/nn/parallel/utils.py b/colossalai/zero/gemini/utils.py similarity index 97% rename from colossalai/nn/parallel/utils.py rename to colossalai/zero/gemini/utils.py index 08fdb6026..e52b5b836 100644 --- a/colossalai/nn/parallel/utils.py +++ b/colossalai/zero/gemini/utils.py @@ -6,9 +6,10 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.gemini.chunk import Chunk from colossalai.utils import get_current_device +from .chunk import Chunk + def get_temp_total_chunk_on_cuda(chunk: Chunk): if chunk.is_gathered: @@ -77,7 +78,7 @@ def get_static_torch_model(zero_ddp_model, Returns: torch.nn.Module: a static torch model used for saving checkpoints or numeric checks """ - from colossalai.nn.parallel import ZeroDDP + from colossalai.zero.gemini.gemini_ddp import ZeroDDP assert isinstance(zero_ddp_model, ZeroDDP) state_dict = zero_ddp_model.state_dict(only_rank_0=only_rank_0) diff --git a/colossalai/zero/legacy/__init__.py b/colossalai/zero/legacy/__init__.py new file mode 100644 index 000000000..35570a1f5 --- /dev/null +++ b/colossalai/zero/legacy/__init__.py @@ -0,0 +1,44 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from colossalai.logging import get_dist_logger + +from .init_ctx import ZeroInitContext, no_shard_zero_context, no_shard_zero_decrator +from .sharded_model import ShardedModelV2 +from .sharded_optim import ShardedOptimizerV2 + + +def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config, + optimizer_config) -> Tuple[ShardedModelV2, ShardedOptimizerV2]: + """ + A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading + + :param model: Your model object + :type model: :class:`torch.nn.Module` + :param optimizer_config: Your optimizer object + :type optimizer_config: :class:`dict` + + :return: (model, optimizer) + :rtype: Tuple + """ + + logger = get_dist_logger('convert_to_zero_v2') + + logger.info(f'optimizer_config is {optimizer_config}', ranks=[0]) + if optimizer_config is None: + optimizer_config = dict() + logger.info(f'model_config is {model_config}', ranks=[0]) + if model_config is None: + model_config = dict() + + zero_model = ShardedModelV2(model, **model_config) + zero_optimizer = ShardedOptimizerV2(zero_model, optimizer, **optimizer_config) + return zero_model, zero_optimizer + + +__all__ = [ + 'convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2', 'ZeroInitContext', 'no_shard_zero_context', + 'no_shard_zero_decrator' +] diff --git a/colossalai/zero/legacy/gemini/__init__.py b/colossalai/zero/legacy/gemini/__init__.py new file mode 100644 index 000000000..754ae9bc0 --- /dev/null +++ b/colossalai/zero/legacy/gemini/__init__.py @@ -0,0 +1,9 @@ +from .ophooks import BaseOpHook, register_ophooks_recursively +from .stateful_tensor import StatefulTensor +from .stateful_tensor_mgr import StatefulTensorMgr +from .tensor_placement_policy import AutoTensorPlacementPolicy, CPUTensorPlacementPolicy, CUDATensorPlacementPolicy + +__all__ = [ + 'StatefulTensorMgr', 'StatefulTensor', 
'CPUTensorPlacementPolicy', 'CUDATensorPlacementPolicy', + 'AutoTensorPlacementPolicy', 'register_ophooks_recursively', 'BaseOpHook' +] diff --git a/colossalai/gemini/gemini_context.py b/colossalai/zero/legacy/gemini/gemini_context.py similarity index 100% rename from colossalai/gemini/gemini_context.py rename to colossalai/zero/legacy/gemini/gemini_context.py diff --git a/colossalai/gemini/ophooks/__init__.py b/colossalai/zero/legacy/gemini/ophooks/__init__.py similarity index 100% rename from colossalai/gemini/ophooks/__init__.py rename to colossalai/zero/legacy/gemini/ophooks/__init__.py diff --git a/colossalai/gemini/ophooks/_shard_grad_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py similarity index 100% rename from colossalai/gemini/ophooks/_shard_grad_ophook.py rename to colossalai/zero/legacy/gemini/ophooks/_shard_grad_ophook.py diff --git a/colossalai/gemini/ophooks/_shard_param_ophook.py b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py similarity index 99% rename from colossalai/gemini/ophooks/_shard_param_ophook.py rename to colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py index 57f76970c..80736d140 100644 --- a/colossalai/gemini/ophooks/_shard_param_ophook.py +++ b/colossalai/zero/legacy/gemini/ophooks/_shard_param_ophook.py @@ -1,4 +1,5 @@ import torch + from colossalai.registry import OPHOOKS from . import BaseOpHook diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py similarity index 96% rename from colossalai/gemini/ophooks/runtime_mem_tracer_hook.py rename to colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py index 6d0df4e61..f40d6ced1 100644 --- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/zero/legacy/gemini/ophooks/runtime_mem_tracer_hook.py @@ -5,9 +5,9 @@ from typing import List import torch -from colossalai.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor -from colossalai.gemini.tensor_utils import alloc_storage, free_storage from colossalai.tensor.param_op_hook import ColoParamOpHook +from colossalai.zero.gemini.memory_tracer import MemStats, SyncCudaMemoryMonitor +from colossalai.zero.legacy.gemini.tensor_utils import alloc_storage, free_storage class TrainingPhase(Enum): diff --git a/colossalai/gemini/ophooks/utils.py b/colossalai/zero/legacy/gemini/ophooks/utils.py similarity index 100% rename from colossalai/gemini/ophooks/utils.py rename to colossalai/zero/legacy/gemini/ophooks/utils.py diff --git a/colossalai/gemini/paramhooks/__init__.py b/colossalai/zero/legacy/gemini/paramhooks/__init__.py similarity index 100% rename from colossalai/gemini/paramhooks/__init__.py rename to colossalai/zero/legacy/gemini/paramhooks/__init__.py diff --git a/colossalai/gemini/paramhooks/_param_hookmgr.py b/colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py similarity index 100% rename from colossalai/gemini/paramhooks/_param_hookmgr.py rename to colossalai/zero/legacy/gemini/paramhooks/_param_hookmgr.py diff --git a/colossalai/gemini/stateful_tensor.py b/colossalai/zero/legacy/gemini/stateful_tensor.py similarity index 97% rename from colossalai/gemini/stateful_tensor.py rename to colossalai/zero/legacy/gemini/stateful_tensor.py index 18fc8fd14..1619ae407 100644 --- a/colossalai/gemini/stateful_tensor.py +++ b/colossalai/zero/legacy/gemini/stateful_tensor.py @@ -1,9 +1,9 @@ from enum import Enum -from typing import Optional -import torch -from typing import Union +from typing import 
Optional, Union -from colossalai.gemini.gemini_context import GeminiMemoryManager +import torch + +from .gemini_context import GeminiMemoryManager def sizeof_tensor(tensor: torch.Tensor): @@ -19,7 +19,7 @@ class TensorState(Enum): class StatefulTensor(object): - """A Structure stores a Torch Tensor and labeled states. + """A Structure stores a Torch Tensor and labeled states. Inspired from the paper: PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management diff --git a/colossalai/gemini/stateful_tensor_mgr.py b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py similarity index 94% rename from colossalai/gemini/stateful_tensor_mgr.py rename to colossalai/zero/legacy/gemini/stateful_tensor_mgr.py index c300f9bff..3b37444b0 100644 --- a/colossalai/gemini/stateful_tensor_mgr.py +++ b/colossalai/zero/legacy/gemini/stateful_tensor_mgr.py @@ -1,13 +1,16 @@ import functools -import torch import types -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy -from typing import List -from colossalai.logging import get_dist_logger from time import time +from typing import List + +import torch + +from colossalai.logging import get_dist_logger +from colossalai.utils.cuda import get_current_device + +from .stateful_tensor import StatefulTensor, TensorState +from .tensor_placement_policy import TensorPlacementPolicy +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage class StatefulTensorMgr(object): diff --git a/colossalai/gemini/tensor_placement_policy.py b/colossalai/zero/legacy/gemini/tensor_placement_policy.py similarity index 96% rename from colossalai/gemini/tensor_placement_policy.py rename to colossalai/zero/legacy/gemini/tensor_placement_policy.py index 0e575254c..165ae51fe 100644 --- a/colossalai/gemini/tensor_placement_policy.py +++ b/colossalai/zero/legacy/gemini/tensor_placement_policy.py @@ -5,11 +5,12 @@ from typing import List, Optional, Type import torch -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.stateful_tensor import StatefulTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.utils import get_current_device from colossalai.utils.memory import colo_device_memory_capacity +from colossalai.zero.gemini.memory_tracer import MemStatsCollector + +from .stateful_tensor import StatefulTensor +from .tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage class TensorPlacementPolicy(ABC): diff --git a/colossalai/gemini/tensor_utils.py b/colossalai/zero/legacy/gemini/tensor_utils.py similarity index 97% rename from colossalai/gemini/tensor_utils.py rename to colossalai/zero/legacy/gemini/tensor_utils.py index bcc159f99..b7f23e025 100644 --- a/colossalai/gemini/tensor_utils.py +++ b/colossalai/zero/legacy/gemini/tensor_utils.py @@ -1,6 +1,8 @@ +from typing import Tuple, Union + import torch -from colossalai.gemini.stateful_tensor import StatefulTensor -from typing import Union, Tuple + +from .stateful_tensor import StatefulTensor def is_storage_empty(tensor: torch.Tensor) -> bool: diff --git a/colossalai/zero/init_ctx/__init__.py b/colossalai/zero/legacy/init_ctx/__init__.py similarity index 100% rename from colossalai/zero/init_ctx/__init__.py rename 
to colossalai/zero/legacy/init_ctx/__init__.py diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py similarity index 97% rename from colossalai/zero/init_ctx/init_context.py rename to colossalai/zero/legacy/init_ctx/init_context.py index b40b69962..f8be0ca4f 100644 --- a/colossalai/zero/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -13,10 +13,10 @@ from colossalai.context.singleton_meta import SingletonMeta from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 -from colossalai.zero.sharded_param import ShardedParamV2 +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 +from colossalai.zero.legacy.sharded_param import ShardedParamV2 @dataclass diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/legacy/shard_utils/__init__.py similarity index 100% rename from colossalai/zero/shard_utils/__init__.py rename to colossalai/zero/legacy/shard_utils/__init__.py diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py similarity index 87% rename from colossalai/zero/shard_utils/base_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/base_shard_strategy.py index 7c2f4c9f6..7ca951091 100644 --- a/colossalai/zero/shard_utils/base_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/base_shard_strategy.py @@ -2,7 +2,8 @@ from abc import ABC, abstractmethod from typing import List, Optional import torch.distributed as dist -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor + +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class BaseShardStrategy(ABC): diff --git a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py similarity index 89% rename from colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py index a7bd7cf53..11297bf6d 100644 --- a/colossalai/zero/shard_utils/bucket_tensor_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/bucket_tensor_shard_strategy.py @@ -2,17 +2,18 @@ from typing import List, Optional import torch import torch.distributed as dist -from colossalai.utils import get_current_device -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from torch._utils import _flatten_dense_tensors as flatten +from colossalai.utils import get_current_device +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor + from .tensor_shard_strategy import TensorShardStrategy class BucketTensorShardStrategy(TensorShardStrategy): - """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, - which will fully utilize network bandwidth. 
- It is especially useful when sub-module contains bias, + """Use the same shard scheme as `TensorShardStrategy`'s, but it gathers tensors of a sub-module together, + which will fully utilize network bandwidth. + It is especially useful when sub-module contains bias, since we cannot utilize network bandwidth well if we only gather a bias tensor (bias is usaully small). """ diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/legacy/shard_utils/commons.py similarity index 95% rename from colossalai/zero/shard_utils/commons.py rename to colossalai/zero/legacy/shard_utils/commons.py index 71cef44c1..bf5ae325c 100644 --- a/colossalai/zero/shard_utils/commons.py +++ b/colossalai/zero/legacy/shard_utils/commons.py @@ -1,7 +1,7 @@ -import torch -import torch.nn.functional as F from typing import Tuple +import torch + def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]: """Return the local shard of a full tensor.""" diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py similarity index 86% rename from colossalai/zero/shard_utils/tensor_shard_strategy.py rename to colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py index 5bdd95400..d1df4803b 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/legacy/shard_utils/tensor_shard_strategy.py @@ -2,11 +2,12 @@ from typing import List, Optional import torch import torch.distributed as dist + from colossalai.utils import get_current_device -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.shard_utils.commons import get_shard -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.shard_utils.commons import get_shard +from colossalai.zero.legacy.sharded_param.sharded_tensor import ShardedTensor class TensorShardStrategy(BaseShardStrategy): @@ -27,7 +28,7 @@ class TensorShardStrategy(BaseShardStrategy): Args: t (ShardedTensor): a tensor to be sharded. - process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. + process_group (Optional[dist.ProcessGroup], optional): the process group among which tensor shards. Defaults to None. 
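The `commons.py` and `tensor_shard_strategy.py` hunks above move the legacy sharding helpers to `colossalai.zero.legacy.shard_utils` without changing their behaviour. As a rough, readability-oriented sketch of what `get_shard(tensor, rank, world_size)` computes (this is not the library code; in particular, treating the returned int as the padding appended so that every rank receives an equally sized slice is an assumption):

from typing import Tuple

import torch
import torch.nn.functional as F


def get_shard_sketch(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    # flatten the tensor, pad it to a multiple of world_size, and keep only this rank's slice
    flat = tensor.flatten()
    shard_size = (flat.numel() + world_size - 1) // world_size
    num_padding = shard_size * world_size - flat.numel()
    padded = F.pad(flat, (0, num_padding))
    return padded[rank * shard_size:(rank + 1) * shard_size].clone(), num_padding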
""" if t.is_sharded: diff --git a/colossalai/zero/sharded_model/__init__.py b/colossalai/zero/legacy/sharded_model/__init__.py similarity index 61% rename from colossalai/zero/sharded_model/__init__.py rename to colossalai/zero/legacy/sharded_model/__init__.py index 725179295..93120bdc3 100644 --- a/colossalai/zero/sharded_model/__init__.py +++ b/colossalai/zero/legacy/sharded_model/__init__.py @@ -1,3 +1,3 @@ from .sharded_model_v2 import ShardedModelV2 -__all__ = ['ShardedModelV2'] \ No newline at end of file +__all__ = ['ShardedModelV2'] diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py similarity index 95% rename from colossalai/zero/sharded_model/_utils.py rename to colossalai/zero/legacy/sharded_model/_utils.py index 85a3ab73d..2bd01531a 100644 --- a/colossalai/zero/sharded_model/_utils.py +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -1,9 +1,9 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Tuple, Union import torch import torch.nn.functional as F -from typing import Union -from colossalai.gemini.stateful_tensor import StatefulTensor + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor def get_gradient_predivide_factor(world_size: int) -> float: diff --git a/colossalai/zero/sharded_model/reduce_scatter.py b/colossalai/zero/legacy/sharded_model/reduce_scatter.py similarity index 100% rename from colossalai/zero/sharded_model/reduce_scatter.py rename to colossalai/zero/legacy/sharded_model/reduce_scatter.py diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py similarity index 97% rename from colossalai/zero/sharded_model/sharded_model_v2.py rename to colossalai/zero/legacy/sharded_model/sharded_model_v2.py index 12e8f65d4..edd2cc8e6 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -13,19 +13,18 @@ from torch.nn.parameter import Parameter from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector -from colossalai.gemini.ophooks import register_ophooks_recursively -from colossalai.gemini.paramhooks import BaseParamHookMgr -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr -from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory -from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu from colossalai.logging import get_dist_logger from colossalai.utils import disposable, get_current_device from colossalai.utils.memory import colo_device_memory_capacity -from colossalai.zero.shard_utils import BaseShardStrategy -from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer -from colossalai.zero.utils import ZeroHook +from colossalai.zero.gemini.memory_tracer import MemStatsCollector, StaticMemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import register_ophooks_recursively +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.gemini.tensor_placement_policy import TensorPlacementPolicy, TensorPlacementPolicyFactory +from 
colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_move_to_cpu +from colossalai.zero.legacy.shard_utils import BaseShardStrategy +from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBucketer from ._utils import ( cast_float_arguments, @@ -35,6 +34,7 @@ from ._utils import ( free_storage, get_gradient_predivide_factor, ) +from .zero_hook import ZeroHook try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/legacy/sharded_model/utils.py similarity index 91% rename from colossalai/zero/sharded_model/utils.py rename to colossalai/zero/legacy/sharded_model/utils.py index 69f5a23ac..08806e78e 100644 --- a/colossalai/zero/sharded_model/utils.py +++ b/colossalai/zero/legacy/sharded_model/utils.py @@ -1,8 +1,9 @@ -import torch -from colossalai.zero.sharded_model import ShardedModelV2 - import copy +import torch + +from colossalai.zero.legacy.sharded_model import ShardedModelV2 + def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Module): """ diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/legacy/sharded_model/zero_hook.py similarity index 92% rename from colossalai/zero/utils/zero_hook.py rename to colossalai/zero/legacy/sharded_model/zero_hook.py index 87bf2c0f5..50f4bdfc7 100644 --- a/colossalai/zero/utils/zero_hook.py +++ b/colossalai/zero/legacy/sharded_model/zero_hook.py @@ -3,14 +3,14 @@ from typing import Optional import torch import torch.distributed as dist -from colossalai.gemini.memory_tracer import MemStatsCollector -from colossalai.gemini.ophooks import BaseOpHook -from colossalai.gemini.stateful_tensor import TensorState -from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr from colossalai.logging import get_dist_logger from colossalai.registry import OPHOOKS from colossalai.utils import get_current_device -from colossalai.zero.shard_utils import BaseShardStrategy +from colossalai.zero.gemini.memory_tracer import MemStatsCollector +from colossalai.zero.legacy.gemini.ophooks import BaseOpHook +from colossalai.zero.legacy.gemini.stateful_tensor import TensorState +from colossalai.zero.legacy.gemini.stateful_tensor_mgr import StatefulTensorMgr +from colossalai.zero.legacy.shard_utils import BaseShardStrategy @OPHOOKS.register_module diff --git a/colossalai/zero/legacy/sharded_optim/__init__.py b/colossalai/zero/legacy/sharded_optim/__init__.py new file mode 100644 index 000000000..b71a70aef --- /dev/null +++ b/colossalai/zero/legacy/sharded_optim/__init__.py @@ -0,0 +1,3 @@ +from .sharded_optim_v2 import ShardedOptimizerV2 + +__all__ = ['ShardedOptimizerV2'] diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py similarity index 97% rename from colossalai/zero/sharded_optim/sharded_optim_v2.py rename to colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py index 43a0b7d76..7ce1c056f 100644 --- a/colossalai/zero/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -14,13 +14,13 @@ from torch.optim import Optimizer from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy -from colossalai.gemini.tensor_utils import 
colo_model_data_tensor_move_inline, colo_tensor_mem_usage from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32 +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_placement_policy import AutoTensorPlacementPolicy +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp32 class OptimState(Enum): diff --git a/colossalai/zero/legacy/sharded_param/__init__.py b/colossalai/zero/legacy/sharded_param/__init__.py new file mode 100644 index 000000000..47e2ce2fa --- /dev/null +++ b/colossalai/zero/legacy/sharded_param/__init__.py @@ -0,0 +1,4 @@ +from .sharded_param import ShardedParamV2 +from .sharded_tensor import ShardedTensor + +__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/legacy/sharded_param/sharded_param.py similarity index 93% rename from colossalai/zero/sharded_param/sharded_param.py rename to colossalai/zero/legacy/sharded_param/sharded_param.py index db0f2d149..4bcc4b621 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/legacy/sharded_param/sharded_param.py @@ -1,9 +1,11 @@ +from typing import List, Optional, Tuple + import torch -from typing import Optional, Tuple -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.gemini.tensor_utils import colo_tensor_mem_usage -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState -from typing import List + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState +from colossalai.zero.legacy.gemini.tensor_utils import colo_tensor_mem_usage + +from .sharded_tensor import ShardedTensor EMPTY_TENSOR_DICT = {} diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/legacy/sharded_param/sharded_tensor.py similarity index 92% rename from colossalai/zero/sharded_param/sharded_tensor.py rename to colossalai/zero/legacy/sharded_param/sharded_tensor.py index 77f4aec30..af6031260 100644 --- a/colossalai/zero/sharded_param/sharded_tensor.py +++ b/colossalai/zero/legacy/sharded_param/sharded_tensor.py @@ -1,5 +1,6 @@ import torch -from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState class ShardedTensor(StatefulTensor): diff --git a/colossalai/zero/low_level/__init__.py b/colossalai/zero/low_level/__init__.py new file mode 100644 index 000000000..ae3c1de3a --- /dev/null +++ b/colossalai/zero/low_level/__init__.py @@ -0,0 +1,3 @@ +from .low_level_optim import LowLevelZeroOptimizer + +__all__ = ['LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_optim/_utils.py b/colossalai/zero/low_level/_utils.py similarity index 100% rename from colossalai/zero/sharded_optim/_utils.py rename to colossalai/zero/low_level/_utils.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/__init__.py b/colossalai/zero/low_level/bookkeeping/__init__.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/__init__.py rename to colossalai/zero/low_level/bookkeeping/__init__.py diff 
--git a/colossalai/zero/sharded_optim/bookkeeping/base_store.py b/colossalai/zero/low_level/bookkeeping/base_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/base_store.py rename to colossalai/zero/low_level/bookkeeping/base_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/bucket_store.py rename to colossalai/zero/low_level/bookkeeping/bucket_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/gradient_store.py b/colossalai/zero/low_level/bookkeeping/gradient_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/gradient_store.py rename to colossalai/zero/low_level/bookkeeping/gradient_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/parameter_store.py b/colossalai/zero/low_level/bookkeeping/parameter_store.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/parameter_store.py rename to colossalai/zero/low_level/bookkeeping/parameter_store.py diff --git a/colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py b/colossalai/zero/low_level/bookkeeping/tensor_bucket.py similarity index 100% rename from colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py rename to colossalai/zero/low_level/bookkeeping/tensor_bucket.py diff --git a/colossalai/zero/sharded_optim/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py similarity index 100% rename from colossalai/zero/sharded_optim/low_level_optim.py rename to colossalai/zero/low_level/low_level_optim.py diff --git a/colossalai/zero/sharded_optim/__init__.py b/colossalai/zero/sharded_optim/__init__.py deleted file mode 100644 index 30c26fb75..000000000 --- a/colossalai/zero/sharded_optim/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .low_level_optim import LowLevelZeroOptimizer -from .sharded_optim_v2 import ShardedOptimizerV2 - -__all__ = ['ShardedOptimizerV2', 'LowLevelZeroOptimizer'] diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py deleted file mode 100644 index 5642a504a..000000000 --- a/colossalai/zero/sharded_param/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor -from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 - -__all__ = ['ShardedTensor', 'ShardedParamV2'] diff --git a/colossalai/zero/utils/__init__.py b/colossalai/zero/utils/__init__.py deleted file mode 100644 index c4e687228..000000000 --- a/colossalai/zero/utils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .zero_hook import ZeroHook - -__all__ = ['ZeroHook'] \ No newline at end of file diff --git a/colossalai/nn/parallel/zero_wrapper.py b/colossalai/zero/wrapper.py similarity index 95% rename from colossalai/nn/parallel/zero_wrapper.py rename to colossalai/zero/wrapper.py index be8d1da7c..4553249e2 100644 --- a/colossalai/nn/parallel/zero_wrapper.py +++ b/colossalai/zero/wrapper.py @@ -4,7 +4,7 @@ from typing import Dict, Optional import torch import torch.nn as nn -from .gemini_parallel import GeminiDDP +from .gemini import GeminiDDP def zero_model_wrapper(model: nn.Module, zero_stage: int = 1, gemini_config: Optional[Dict] = None): @@ -99,11 +99,11 @@ def zero_optim_wrapper(model: nn.Module, config_dict['max_scale'] = max_scale if zero_stage in [1, 2]: - from colossalai.zero.sharded_optim.low_level_optim import LowLevelZeroOptimizer + from 
colossalai.zero.low_level import LowLevelZeroOptimizer config_dict['partition_grad'] = zero_stage == 2 config_dict['clip_grad_norm'] = max_norm return LowLevelZeroOptimizer(optimizer, **config_dict) else: - from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer + from colossalai.zero.gemini.gemini_optimizer import ZeroOptimizer config_dict['clipping_norm'] = max_norm return ZeroOptimizer(optimizer, model, **config_dict) diff --git a/docs/source/en/features/nvme_offload.md b/docs/source/en/features/nvme_offload.md index 2933c3db6..38d2c4af9 100644 --- a/docs/source/en/features/nvme_offload.md +++ b/docs/source/en/features/nvme_offload.md @@ -78,7 +78,7 @@ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext ``` diff --git a/docs/source/zh-Hans/features/nvme_offload.md b/docs/source/zh-Hans/features/nvme_offload.md index f33474efa..fd75ed1f5 100644 --- a/docs/source/zh-Hans/features/nvme_offload.md +++ b/docs/source/zh-Hans/features/nvme_offload.md @@ -77,7 +77,7 @@ from transformers.models.gpt2.configuration_gpt2 import GPT2Config from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper +from colossalai.zero import zero_model_wrapper, zero_optim_wrapper from colossalai.utils.model.colo_init_context import ColoInitContext ``` diff --git a/examples/images/dreambooth/debug.py b/examples/images/dreambooth/debug.py index c4adb4823..33219b2ca 100644 --- a/examples/images/dreambooth/debug.py +++ b/examples/images/dreambooth/debug.py @@ -5,7 +5,7 @@ torchrun --standalone --nproc_per_node=1 debug.py from diffusers import AutoencoderKL import colossalai -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, post_process_colo_init_ctx path = "/data/scratch/diffuser/stable-diffusion-v1-4" diff --git a/examples/images/dreambooth/train_dreambooth_colossalai.py b/examples/images/dreambooth/train_dreambooth_colossalai.py index 5c4c86bc7..e6159e105 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai.py @@ -21,10 +21,9 @@ import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() diff --git a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py index 3d789ae2c..1b2fc778d 100644 --- a/examples/images/dreambooth/train_dreambooth_colossalai_lora.py +++ b/examples/images/dreambooth/train_dreambooth_colossalai_lora.py @@ -23,10 +23,9 @@ import colossalai from 
colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer +from colossalai.zero.gemini import get_static_torch_model disable_existing_loggers() logger = get_dist_logger() diff --git a/examples/images/vit/test_vit.py b/examples/images/vit/test_vit.py index 90f2475b8..6a587e1df 100644 --- a/examples/images/vit/test_vit.py +++ b/examples/images/vit/test_vit.py @@ -18,7 +18,7 @@ from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, Proc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext def set_seed(seed): diff --git a/examples/images/vit/train.py b/examples/images/vit/train.py index 0b4489244..b42cf2bed 100644 --- a/examples/images/vit/train.py +++ b/examples/images/vit/train.py @@ -19,7 +19,7 @@ from colossalai.nn.optimizer import HybridAdam from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext def init_1d_row_for_linear_weight_spec(model, world_size: int): diff --git a/examples/language/bert/train_bert_demo.py b/examples/language/bert/train_bert_demo.py index b690ff787..9a0278b2c 100644 --- a/examples/language/bert/train_bert_demo.py +++ b/examples/language/bert/train_bert_demo.py @@ -12,10 +12,9 @@ from transformers import AlbertConfig, AlbertForSequenceClassification, BertConf import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper CAI_VERSION = colossalai.__version__ diff --git a/examples/language/gpt/gemini/train_gpt_demo.py b/examples/language/gpt/gemini/train_gpt_demo.py index f46226bce..b2a7fa36d 100644 --- a/examples/language/gpt/gemini/train_gpt_demo.py +++ b/examples/language/gpt/gemini/train_gpt_demo.py @@ -13,10 +13,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, 
zero_optim_wrapper CAI_VERSION = colossalai.__version__ diff --git a/examples/language/opt/train_gemini_opt.py b/examples/language/opt/train_gemini_opt.py index 4993ce25d..4874f831c 100755 --- a/examples/language/opt/train_gemini_opt.py +++ b/examples/language/opt/train_gemini_opt.py @@ -34,12 +34,9 @@ from transformers.utils.versions import require_version import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP -from colossalai.utils import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext - from colossalai.tensor import ProcessGroup, ShardSpec +from colossalai.utils import get_current_device +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP def get_data(batch_size, seq_len, vocab_size): @@ -179,13 +176,15 @@ def main(): # build model if args.model_name_or_path is None: logger.info("Train a new model from scratch", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM(config) else: logger.info("Finetune a pre-trained model", ranks=[0]) - with ColoInitContext(device=init_dev, dtype=torch.half, + with ColoInitContext(device=init_dev, + dtype=torch.half, default_dist_spec=default_dist_spec, default_pg=shard_pg): model = OPTForCausalLM.from_pretrained(args.model_name_or_path, @@ -198,8 +197,11 @@ def main(): numel = sum([p.numel() for p in model.parameters()]) PLACEMENT_POLICY = 'cpu' - model = GeminiDDP(model, device=get_current_device(), placement_policy=PLACEMENT_POLICY, - pin_memory=True, strict_ddp_mode=args.shardinit) + model = GeminiDDP(model, + device=get_current_device(), + placement_policy=PLACEMENT_POLICY, + pin_memory=True, + strict_ddp_mode=args.shardinit) optimizer = GeminiAdamOptimizer(model, lr=args.learning_rate, initial_scale=2**14, gpu_margin_mem_ratio=0.0) SEQ_LEN = 1024 diff --git a/examples/language/palm/train.py b/examples/language/palm/train.py index 2f012780d..7923e4fc8 100644 --- a/examples/language/palm/train.py +++ b/examples/language/palm/train.py @@ -15,11 +15,9 @@ from torch.utils.data import DataLoader, Dataset import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec from colossalai.utils import MultiTimer, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP # constants @@ -127,7 +125,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: return model -## Parameter Sharding Strategies for Tensor Parallelism +# Parameter Sharding Strategies for Tensor Parallelism def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup): spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) param.set_tensor_spec(*spec) @@ -232,7 +230,7 @@ if args.distplan == "colossalai": tensor_parallelize(model, pg) model = gemini_zero_dpp(model, pg, args.placement) - #optimizer + # optimizer #optimizer = GeminiAdamOptimizer(model, lr=1e-7, initial_scale=2**5) optimizer = 
GeminiAdamOptimizer(model, lr=LEARNING_RATE, initial_scale=2**5) diff --git a/examples/language/roberta/pretraining/run_pretraining.py b/examples/language/roberta/pretraining/run_pretraining.py index 9840a122c..eef7bb6ad 100644 --- a/examples/language/roberta/pretraining/run_pretraining.py +++ b/examples/language/roberta/pretraining/run_pretraining.py @@ -1,69 +1,67 @@ -import colossalai import math -import torch -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -import colossalai.nn as col_nn -from arguments import parse_args -from pretrain_utils import get_model, get_optimizer, get_lr_scheduler, save_ckpt -from utils.exp_util import get_tflops, get_mem_info, throughput_calculator, log_args -from utils.global_vars import set_global_variables, get_timers, get_tensorboard_writer -from utils.logger import Logger -from evaluation import evaluate -from loss import LossForPretraining - -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider -from tqdm import tqdm import os import time from functools import partial +import torch +from arguments import parse_args +from evaluation import evaluate +from loss import LossForPretraining +from nvidia_bert_dataset_provider import NvidiaBertDatasetProvider +from pretrain_utils import get_lr_scheduler, get_model, get_optimizer, save_ckpt +from tqdm import tqdm from transformers import AutoTokenizer +from utils.exp_util import get_mem_info, get_tflops, log_args, throughput_calculator +from utils.global_vars import get_tensorboard_writer, get_timers, set_global_variables +from utils.logger import Logger -from colossalai.gemini import ChunkManager, GeminiManager -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.utils import get_current_device -from colossalai.nn.parallel import ZeroDDP -from colossalai.zero import ZeroOptimizer -from colossalai.tensor import ProcessGroup +import colossalai +import colossalai.nn as col_nn +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc from colossalai.nn.optimizer import HybridAdam +from colossalai.nn.parallel import ZeroDDP +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device +from colossalai.zero import ZeroOptimizer +from colossalai.zero.gemini import ChunkManager, ColoInitContext, GeminiManager +from colossalai.zero.legacy import ShardedModelV2, ShardedOptimizerV2, ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy def main(): args = parse_args() launch_time = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) - + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) os.environ['CUDA_LAUNCH_BLOCKING'] = '1' - + logger = Logger(os.path.join(args.log_path, launch_time), cuda=torch.cuda.is_available(), debug=args.vscode_debug) - + if args.vscode_debug: colossalai.launch(config={}, - rank=args.rank, - world_size=args.world_size, - host=args.host, - port=args.port, - backend=args.backend) + rank=args.rank, + world_size=args.world_size, + host=args.host, + port=args.port, + backend=args.backend) args.local_rank = -1 args.log_interval = 1 else: - colossalai.launch_from_torch(args.colossal_config) #args.colossal_config + colossalai.launch_from_torch(args.colossal_config) # 
args.colossal_config args.local_rank = int(os.environ["LOCAL_RANK"]) - logger.info(f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + - f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}') + logger.info( + f'launch_from_torch, world size: {torch.distributed.get_world_size()} | ' + + f'ParallelMode.MODEL: {ParallelMode.MODEL} | ParallelMode.DATA: {ParallelMode.DATA} | ParallelMode.TENSOR: {ParallelMode.TENSOR}' + ) log_args(logger, args) args.tokenizer = tokenizer args.logger = logger set_global_variables(launch_time, args.tensorboard_path) - + use_zero = hasattr(gpc.config, 'zero') world_size = torch.distributed.get_world_size() @@ -71,8 +69,8 @@ def main(): if use_zero: shard_strategy = TensorShardStrategy() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=shard_strategy, - shard_param=True): - + shard_param=True): + config, model, numel = get_model(args, logger) # model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cpu', reuse_fp16_shard=True) else: @@ -82,9 +80,10 @@ def main(): os.mkdir(os.path.join(args.ckpt_path, launch_time)) logger.info(f'Model numel: {numel}') - + get_tflops_func = partial(get_tflops, numel, args.train_micro_batch_size_per_gpu, args.max_seq_length) - steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size #len(dataloader) + # len(dataloader) + steps_per_epoch = 144003367 // world_size // args.train_micro_batch_size_per_gpu // args.gradient_accumulation_steps // args.refresh_bucket_size total_steps = steps_per_epoch * args.epoch # build optimizer and lr_scheduler @@ -98,18 +97,23 @@ def main(): o_l_state_dict['lr_scheduler']['last_epoch'] = o_l_state_dict['lr_scheduler']['last_epoch'] - 1 optimizer = get_optimizer(model, lr=args.lr) optimizer.load_state_dict(o_l_state_dict['optimizer']) - lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) #o_l_state_dict['lr_scheduler']['last_epoch'] + # o_l_state_dict['lr_scheduler']['last_epoch'] + lr_scheduler = get_lr_scheduler(optimizer, + total_steps=total_steps, + last_epoch=o_l_state_dict['lr_scheduler']['last_epoch']) for state in optimizer.state.values(): for k, v in state.items(): if isinstance(v, torch.Tensor): state[k] = v.cuda(f"cuda:{torch.cuda.current_device()}") # if you want delete the above three code, have to move the model to gpu, because in optimizer.step() lr_scheduler.load_state_dict(o_l_state_dict['lr_scheduler']) - + start_epoch = o_l_state_dict['epoch'] start_shard = o_l_state_dict['shard'] + 1 # global_step = o_l_state_dict['global_step'] + 1 - logger.info(f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}') + logger.info( + f'resume from epoch {start_epoch} shard {start_shard} step {lr_scheduler.last_epoch} lr {lr_scheduler.get_last_lr()[0]}' + ) else: optimizer = get_optimizer(model, lr=args.lr) lr_scheduler = get_lr_scheduler(optimizer, total_steps=total_steps, last_epoch=-1) @@ -124,12 +128,11 @@ def main(): # initialize with colossalai engine, _, _, lr_scheduelr = colossalai.initialize(model=model, - optimizer=optimizer, - criterion=criterion, - lr_scheduler=lr_scheduler) - + optimizer=optimizer, + criterion=criterion, + lr_scheduler=lr_scheduler) + logger.info(get_mem_info(prefix='After init model, ')) - best_loss = None 
eval_loss = 0 @@ -146,13 +149,16 @@ def main(): dataset_iterator, total_length = pretrain_dataset_provider.get_shard(shard) # pretrain_dataset_provider.prefetch_shard(shard + 1) # may cause cpu memory overload if torch.distributed.get_rank() == 0: - iterator_data = tqdm(enumerate(dataset_iterator), total=(total_length // args.train_micro_batch_size_per_gpu // world_size), colour='cyan', smoothing=1) + iterator_data = tqdm(enumerate(dataset_iterator), + total=(total_length // args.train_micro_batch_size_per_gpu // world_size), + colour='cyan', + smoothing=1) else: iterator_data = enumerate(dataset_iterator) engine.train() - - for step, batch_data in iterator_data: + + for step, batch_data in iterator_data: # batch_data = pretrain_dataset_provider.get_batch(batch_index) input_ids = batch_data[0].cuda(f"cuda:{torch.cuda.current_device()}") @@ -162,7 +168,7 @@ def main(): # nsp_label = batch_data[5].cuda() output = engine(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) - + loss = engine.criterion(output.logits, mlm_label) pretrain_dataset_provider.prefetch_batch() @@ -172,14 +178,15 @@ def main(): engine.step() lr_scheduelr.step() engine.zero_grad() - + global_step += 1 if global_step % args.log_interval == 0 and global_step != 0 \ - and torch.distributed.get_rank() == 0: + and torch.distributed.get_rank() == 0: elapsed_time = timers('interval_time').elapsed(reset=False) elapsed_time_per_iteration = elapsed_time / global_step - samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator(numel, args, config, elapsed_time, global_step, world_size) + samples_per_sec, tflops, approx_parameters_in_billions = throughput_calculator( + numel, args, config, elapsed_time, global_step, world_size) cur_loss = train_loss / args.log_interval current_lr = lr_scheduelr.get_last_lr()[0] @@ -189,12 +196,13 @@ def main(): if args.wandb: tensorboard_log = get_tensorboard_writer() - tensorboard_log.log_train({ - 'lr': current_lr, - 'loss': cur_loss, - 'ppl': math.exp(cur_loss), - 'mins_batch': elapsed_time_per_iteration - }, global_step) + tensorboard_log.log_train( + { + 'lr': current_lr, + 'loss': cur_loss, + 'ppl': math.exp(cur_loss), + 'mins_batch': elapsed_time_per_iteration + }, global_step) train_loss = 0 @@ -202,12 +210,14 @@ def main(): logger.info('*' * 100) eval_loss += evaluate(engine, args, logger, global_step) - save_ckpt(engine.model, optimizer, lr_scheduelr, os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, shard, global_step) - - + save_ckpt(engine.model, optimizer, lr_scheduelr, + os.path.join(args.ckpt_path, launch_time, f'epoch-{epoch}_shard-{shard}_' + launch_time), epoch, + shard, global_step) + eval_loss /= len(os.listdir(args.data_path_prefix)) - logger.info(f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + \ - f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') + logger.info( + f'epoch {epoch} | shard_length {len(os.listdir(args.data_path_prefix))} | elapsed_time: {timers("epoch_time").elapsed() / 60 :.3f} mins' + + f'eval_loss: {eval_loss} | ppl: {math.exp(eval_loss)}') logger.info('-' * 100) if args.wandb and torch.distributed.get_rank() == 0: tensorboard_log = get_tensorboard_writer() diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py index c4f576cb1..e618b4d66 100755 --- a/examples/tutorial/opt/opt/run_clm.py +++ b/examples/tutorial/opt/opt/run_clm.py @@ 
-30,24 +30,13 @@ from itertools import chain import datasets import torch import torch.distributed as dist +import transformers from accelerate.utils import set_seed from context import barrier_context from datasets import load_dataset from packaging import version from torch.utils.data import DataLoader from tqdm.auto import tqdm - -import colossalai -import transformers -from colossalai.context import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers, get_dist_logger -from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP -from colossalai.tensor import ProcessGroup -from colossalai.utils import get_current_device, get_dataloader -from colossalai.utils.model.colo_init_context import ColoInitContext from transformers import ( CONFIG_MAPPING, MODEL_MAPPING, @@ -61,6 +50,15 @@ from transformers import ( ) from transformers.utils.versions import require_version +import colossalai +from colossalai.context import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.logging import disable_existing_loggers, get_dist_logger +from colossalai.nn.optimizer import HybridAdam +from colossalai.tensor import ProcessGroup +from colossalai.utils import get_current_device, get_dataloader +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer + require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) diff --git a/tests/test_auto_parallel/test_offload/test_perf.py b/tests/test_auto_parallel/test_offload/test_perf.py index 17bf9cb87..c925843fb 100644 --- a/tests/test_auto_parallel/test_offload/test_perf.py +++ b/tests/test_auto_parallel/test_offload/test_perf.py @@ -12,10 +12,9 @@ from colossalai.auto_parallel.offload.mem_optimize import memory_optimize from colossalai.auto_parallel.offload.solver import NOT_NVML from colossalai.fx.profiler import parameter_size from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.testing import parameterize from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper from tests.test_auto_parallel.test_offload.model_utils import * from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py index 760401c3f..9879ae461 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_compatibility_with_gemini.py @@ -11,12 +11,11 @@ from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper from colossalai.tensor.process_group import ProcessGroup from colossalai.testing import assert_close, rerun_if_address_is_in_use from colossalai.testing.pytest_wrapper import run_on_environment_flag from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context 
import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, post_process_colo_init_ctx, zero_model_wrapper, zero_optim_wrapper class MLP(torch.nn.Module): diff --git a/tests/test_ddp/test_ddp_ignore_params.py b/tests/test_ddp/test_ddp_ignore_params.py index 679c8b0f6..2ad20f6be 100644 --- a/tests/test_ddp/test_ddp_ignore_params.py +++ b/tests/test_ddp/test_ddp_ignore_params.py @@ -10,14 +10,14 @@ import torch.distributed as dist import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.parallel import ColoDDP, ZeroDDP +from colossalai.nn.parallel import ColoDDP from colossalai.tensor import ProcessGroup from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager def set_seed(seed): diff --git a/tests/test_ddp/test_ddp_state_dict.py b/tests/test_ddp/test_ddp_state_dict.py index f229364c6..bd4742ff2 100644 --- a/tests/test_ddp/test_ddp_state_dict.py +++ b/tests/test_ddp/test_ddp_state_dict.py @@ -1,18 +1,19 @@ import copy +from collections import OrderedDict +from functools import partial import pytest -import colossalai import torch import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from functools import partial -from tests.components_to_test.registry import non_distributed_component_funcs + +import colossalai from colossalai.nn.parallel import ColoDDP -from collections import OrderedDict -from colossalai.tensor import ProcessGroup, ColoParameter +from colossalai.tensor import ColoParameter, ProcessGroup +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext +from tests.components_to_test.registry import non_distributed_component_funcs def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict): diff --git a/tests/test_gemini/test_gemini_manager.py b/tests/test_gemini/test_gemini_manager.py index 0c138f101..aee943253 100644 --- a/tests/test_gemini/test_gemini_manager.py +++ b/tests/test_gemini/test_gemini_manager.py @@ -1,73 +1,73 @@ -import pytest -import torch - -from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor - - -@pytest.mark.dist -def test_gemini_manager(): - # reset the manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - 
assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() +import pytest +import torch + +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState + + +@pytest.mark.dist +def test_gemini_manager(): + # reset the manager, in case that there exists memory information left + manager = StatefulTensor.GST_MGR + manager.reset() + + # occupation 8 + st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) + # occupation 60 + st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) + + # occupation 28 + t1 = torch.empty(7, device='cuda') + # occupation 12 + t2 = torch.empty(3, device='cpu') + st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) + st4 = StatefulTensor(None, TensorState.FREE) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 60 + assert manager.total_mem['cuda'] == 36 + assert manager.state_mem['cpu'][TensorState.HOLD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 + + st4.payload_reset(t2) + st3.payload_reset(t2) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 84 + assert manager.total_mem['cuda'] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD] == 72 + assert manager.state_mem['cuda'][TensorState.HOLD] == 8 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 + + st1.move_to(torch.device('cpu')) + st2.move_to(torch.device('cpu')) + st3.move_to(torch.device('cuda', 0)) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert 
manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 80 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + + st1.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.COMPUTE) + st2.trans_state(TensorState.HOLD_AFTER_BWD) + + assert manager.total_number == 4 + assert manager.total_mem['cpu'] == 80 + assert manager.total_mem['cuda'] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD] == 12 + assert manager.state_mem['cuda'][TensorState.HOLD] == 0 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 + assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 + assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 + assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 + assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 + + +if __name__ == '__main__': + test_gemini_manager() diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py index daf386d6d..9ebacdb70 100644 --- a/tests/test_gemini/test_param_op.py +++ b/tests/test_gemini/test_param_op.py @@ -2,7 +2,7 @@ import copy import torch -from colossalai.gemini.paramhooks import BaseParamHookMgr +from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index 294868458..9a3e93493 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -3,8 +3,8 @@ from copy import deepcopy import numpy as np import torch -from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_chunk_mgrv2.py b/tests/test_gemini/update/test_chunk_mgrv2.py index 7d192fc63..ba0945551 100644 --- a/tests/test_gemini/update/test_chunk_mgrv2.py +++ b/tests/test_gemini/update/test_chunk_mgrv2.py @@ -5,10 +5,10 @@ import torch import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager from colossalai.tensor import ColoTensor, ColoTensorSpec, ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port +from colossalai.zero.gemini.chunk import ChunkManager from tests.test_tensor.common_utils import debug_print CUDA_MEM_0 = {False: 512, True: 1024} diff --git a/tests/test_gemini/update/test_chunkv2.py b/tests/test_gemini/update/test_chunkv2.py index 96855410b..5f9ba5d3a 100644 --- a/tests/test_gemini/update/test_chunkv2.py +++ b/tests/test_gemini/update/test_chunkv2.py @@ -6,12 +6,12 @@ import torch.distributed as dist import torch.multiprocessing as mp import colossalai -from colossalai.gemini import TensorState -from colossalai.gemini.chunk import Chunk from colossalai.tensor import ColoParameter from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.testing import parameterize, 
rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device +from colossalai.zero.gemini import TensorState +from colossalai.zero.gemini.chunk import Chunk def dist_sum(x): diff --git a/tests/test_gemini/update/test_fwd_bwd.py b/tests/test_gemini/update/test_fwd_bwd.py index 2821dc78d..8cfacd018 100644 --- a/tests/test_gemini/update/test_fwd_bwd.py +++ b/tests/test_gemini/update/test_fwd_bwd.py @@ -8,16 +8,14 @@ from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py index 8cf17a0a7..9d5419e94 100644 --- a/tests/test_gemini/update/test_gemini_use_rmt.py +++ b/tests/test_gemini/update/test_gemini_use_rmt.py @@ -5,15 +5,13 @@ import torch import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager +from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_gemini/update/test_get_torch_model.py b/tests/test_gemini/update/test_get_torch_model.py index e6d586b37..c014ced97 100644 --- a/tests/test_gemini/update/test_get_torch_model.py +++ b/tests/test_gemini/update/test_get_torch_model.py @@ -6,13 +6,12 @@ import torch import torch.multiprocessing as mp import colossalai -from colossalai.nn.parallel import GeminiDDP -from colossalai.nn.parallel.utils import get_static_torch_model from colossalai.tensor import ColoParameter from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import 
get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiDDP +from colossalai.zero.gemini.utils import get_static_torch_model from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_grad_clip.py b/tests/test_gemini/update/test_grad_clip.py index d97ba9439..65f252c55 100644 --- a/tests/test_gemini/update/test_grad_clip.py +++ b/tests/test_gemini/update/test_grad_clip.py @@ -10,15 +10,13 @@ from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_inference.py b/tests/test_gemini/update/test_inference.py index b057448ad..12392d6e5 100644 --- a/tests/test_gemini/update/test_inference.py +++ b/tests/test_gemini/update/test_inference.py @@ -10,15 +10,13 @@ from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx, zero_model_wrapper +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_optim.py b/tests/test_gemini/update/test_optim.py index cd3aa6051..7364e59d1 100644 --- a/tests/test_gemini/update/test_optim.py +++ b/tests/test_gemini/update/test_optim.py @@ -9,16 +9,14 @@ from torch.testing import assert_close import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from 
colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel import ZeroDDP from colossalai.tensor import ColoParameter, ColoTensor from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer, post_process_colo_init_ctx +from colossalai.zero.gemini.chunk import ChunkManager, init_chunk_manager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test import run_fwd_bwd from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_search.py b/tests/test_gemini/update/test_search.py index 2fcdd5380..71cdf9a18 100644 --- a/tests/test_gemini/update/test_search.py +++ b/tests/test_gemini/update/test_search.py @@ -6,11 +6,11 @@ import torch.distributed as dist import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext +from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_gemini/update/test_zeroddp_state_dict.py b/tests/test_gemini/update/test_zeroddp_state_dict.py index 00d835842..7e759808d 100644 --- a/tests/test_gemini/update/test_zeroddp_state_dict.py +++ b/tests/test_gemini/update/test_zeroddp_state_dict.py @@ -7,13 +7,12 @@ import torch.multiprocessing as mp from torch.testing import assert_close import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager -from colossalai.nn.parallel import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_gemini/update/test_zerooptim_state_dict.py b/tests/test_gemini/update/test_zerooptim_state_dict.py index fd13af6b2..996dc4eb8 100644 --- a/tests/test_gemini/update/test_zerooptim_state_dict.py +++ b/tests/test_gemini/update/test_zerooptim_state_dict.py @@ -6,15 +6,13 @@ import torch.distributed as dist import torch.multiprocessing as mp import colossalai -from colossalai.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.gemini.gemini_mgr import GeminiManager from colossalai.nn.optimizer import HybridAdam -from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer -from colossalai.nn.parallel 
import ZeroDDP from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer +from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration +from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import debug_print, set_seed diff --git a/tests/test_moe/test_moe_checkpoint.py b/tests/test_moe/test_moe_checkpoint.py index f99e74ea5..5b6fe4411 100644 --- a/tests/test_moe/test_moe_checkpoint.py +++ b/tests/test_moe/test_moe_checkpoint.py @@ -11,7 +11,7 @@ from colossalai.context import MOE_CONTEXT from colossalai.nn.layer.moe import load_moe_model, save_moe_model from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_tensor.common_utils import debug_print from tests.test_zero.common import CONFIG diff --git a/tests/test_moe/test_moe_colo_init.py b/tests/test_moe/test_moe_colo_init.py index ae0c1390c..23ad1a3dc 100644 --- a/tests/test_moe/test_moe_colo_init.py +++ b/tests/test_moe/test_moe_colo_init.py @@ -1,63 +1,60 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.distributed as dist -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.tensor import ColoParameter -from colossalai.utils.model.colo_init_context import ColoInitContext - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device - -from tests.test_zero.common import CONFIG -from tests.test_moe.test_moe_zero_init import MoeModel -from tests.test_tensor.common_utils import debug_print - - -@parameterize("init_device_type", ['cpu', 'cuda']) -def exam_moe_colo_init(init_device_type): - world_size = dist.get_world_size() - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - with ColoInitContext(device=init_device): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) - - if hasattr(param, "moe_info"): - param.set_process_group(param.moe_info.pg) - - if hasattr(param, "moe_info"): - assert param.process_group.dp_world_size() == param.moe_info.dp_size - else: - assert param.process_group.dp_world_size() == world_size - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - exam_moe_colo_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [4]) -@rerun_if_address_is_in_use() -def test_moe_colo_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - 
test_moe_colo_init(world_size=4) +from functools import partial + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.tensor import ColoParameter +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero import ColoInitContext +from tests.test_moe.test_moe_zero_init import MoeModel +from tests.test_tensor.common_utils import debug_print +from tests.test_zero.common import CONFIG + + +@parameterize("init_device_type", ['cpu', 'cuda']) +def exam_moe_colo_init(init_device_type): + world_size = dist.get_world_size() + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + with ColoInitContext(device=init_device): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert isinstance(param, ColoParameter), "parameter `{}` has an init problem".format(name) + + if hasattr(param, "moe_info"): + param.set_process_group(param.moe_info.pg) + + if hasattr(param, "moe_info"): + assert param.process_group.dp_world_size() == param.moe_info.dp_size + else: + assert param.process_group.dp_world_size() == world_size + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + exam_moe_colo_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [4]) +@rerun_if_address_is_in_use() +def test_moe_colo_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_colo_init(world_size=4) diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py index 04dc9c514..5987e31f7 100644 --- a/tests/test_moe/test_moe_zero_init.py +++ b/tests/test_moe/test_moe_zero_init.py @@ -1,114 +1,112 @@ -from functools import partial - -import colossalai -import pytest -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from colossalai.nn import CheckpointModule -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize -from colossalai.utils import free_port -from colossalai.context import MOE_CONTEXT -from colossalai.nn.layer import MoeModule -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) - -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils import get_current_device -from tests.test_zero.common import CONFIG - - -class MoeModel(nn.Module): - - def __init__(self, checkpoint: bool = False): - - class TestSubModule(CheckpointModule): - - def __init__(self): - super().__init__(checkpoint) - expert_cls = nn.Linear - expert_args_dict = dict(in_features=16, out_features=16) - self.moe = MoeModule(dim_model=16, - num_experts=8, - use_residual=True, - expert_cls=expert_cls, - **expert_args_dict) - self.proj = nn.Linear(16, 4) - - def _forward(self, x): - x, y = self.moe(x) - x = self.proj(x) - return x, y - - super().__init__() - self.test_embed = nn.Linear(4, 16) - self.test_transform = TestSubModule() - - def forward(self, x): - MOE_CONTEXT.reset_loss() - - x = 
self.test_embed(x) - x, y = self.test_transform(x) - - MOE_CONTEXT.add_loss(y) - return x - - -@parameterize("init_device_type", ['cpu', 'cuda']) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_moe_zero_init(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_moe_zero_init") - - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - raise NotImplementedError("Unknown device found.") - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = MoeModel(checkpoint=True) - - for name, param in model.named_parameters(): - assert hasattr(param, 'colo_attr') - - # the parameters in moe experts and its gate should not be sharded - if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): - assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) - else: - assert param.colo_attr.sharded_data_tensor.is_sharded - - # the parameters in moe experts is not replicated - if 'experts' in name: - assert not param.colo_attr.is_replicated - else: - assert param.colo_attr.is_replicated - - if param.colo_attr.param_is_sharded: - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - else: - assert param.colo_attr.data_payload.device.type == 'cuda' - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - MOE_CONTEXT.setup(seed=42) - run_moe_zero_init() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_moe_zero_init(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_moe_zero_init(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +import torch.nn as nn + +import colossalai +from colossalai.context import MOE_CONTEXT +from colossalai.logging import get_dist_logger +from colossalai.nn import CheckpointModule +from colossalai.nn.layer import MoeModule +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port, get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from tests.test_zero.common import CONFIG + + +class MoeModel(nn.Module): + + def __init__(self, checkpoint: bool = False): + + class TestSubModule(CheckpointModule): + + def __init__(self): + super().__init__(checkpoint) + expert_cls = nn.Linear + expert_args_dict = dict(in_features=16, out_features=16) + self.moe = MoeModule(dim_model=16, + num_experts=8, + use_residual=True, + expert_cls=expert_cls, + **expert_args_dict) + self.proj = nn.Linear(16, 4) + + def _forward(self, x): + x, y = self.moe(x) + x = self.proj(x) + return x, y + + super().__init__() + self.test_embed = nn.Linear(4, 16) + self.test_transform = TestSubModule() + + def forward(self, x): + MOE_CONTEXT.reset_loss() + + x = self.test_embed(x) + x, y = self.test_transform(x) + + MOE_CONTEXT.add_loss(y) + return x + + 
+@parameterize("init_device_type", ['cpu', 'cuda']) +@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) +def run_moe_zero_init(init_device_type, shard_strategy_class): + logger = get_dist_logger("test_moe_zero_init") + + if init_device_type == 'cuda': + init_device = get_current_device() + elif init_device_type == 'cpu': + init_device = torch.device("cpu") + else: + raise NotImplementedError("Unknown device found.") + + model_numel_tensor = torch.zeros(1, dtype=torch.int) + with ZeroInitContext(target_device=init_device, + shard_strategy=shard_strategy_class(), + shard_param=True, + model_numel_tensor=model_numel_tensor): + model = MoeModel(checkpoint=True) + + for name, param in model.named_parameters(): + assert hasattr(param, 'colo_attr') + + # the parameters in moe experts and its gate should not be sharded + if ('experts' in name) or ('gate' in name) or ('residual_combine' in name): + assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name) + else: + assert param.colo_attr.sharded_data_tensor.is_sharded + + # the parameters in moe experts is not replicated + if 'experts' in name: + assert not param.colo_attr.is_replicated + else: + assert param.colo_attr.is_replicated + + if param.colo_attr.param_is_sharded: + assert param.colo_attr.data_payload.device.type == init_device.type, \ + f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' + else: + assert param.colo_attr.data_payload.device.type == 'cuda' + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + MOE_CONTEXT.setup(seed=42) + run_moe_zero_init() + + +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [2, 4]) +@rerun_if_address_is_in_use() +def test_moe_zero_init(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_moe_zero_init(world_size=2) diff --git a/tests/test_moe/test_moe_zero_model.py b/tests/test_moe/test_moe_zero_model.py index d608ebf07..d38f66fef 100644 --- a/tests/test_moe/test_moe_zero_model.py +++ b/tests/test_moe/test_moe_zero_model.py @@ -10,11 +10,11 @@ from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.nn import MoeLoss from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_zero.common import CONFIG, check_grads_padding, run_fwd_bwd diff --git a/tests/test_moe/test_moe_zero_optim.py b/tests/test_moe/test_moe_zero_optim.py index 
9d9a7bd17..7e140bf86 100644 --- a/tests/test_moe/test_moe_zero_optim.py +++ b/tests/test_moe/test_moe_zero_optim.py @@ -12,12 +12,12 @@ from colossalai.nn import MoeLoss from colossalai.nn.optimizer import CPUAdam from colossalai.testing import assert_equal_in_group, parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_moe.test_moe_zero_init import MoeModel from tests.test_zero.common import CONFIG, check_sharded_model_params diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py index d317dc2e3..ea1c044f5 100644 --- a/tests/test_optimizer/test_cpu_adam.py +++ b/tests/test_optimizer/test_cpu_adam.py @@ -56,7 +56,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype): eps = 1e-8 weight_decay = 0 - for i in range(1024): + for i in range(3): p_data = torch.rand(64, dtype=p_dtype) p_data_copy = p_data.clone().float() p_grad = torch.rand(64, dtype=g_dtype) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py index 7b9b6e9c4..8ff6618ae 100644 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ b/tests/test_optimizer/test_fused_adam_kernel.py @@ -54,7 +54,7 @@ def test_adam(adamw, step, p_dtype, g_dtype): count = 0 - for i in range(1024): + for i in range(3): p = torch.rand(64, dtype=p_dtype).cuda() p_copy = p.clone().float() g = torch.rand(p.shape, dtype=g_dtype).cuda() diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py index d19192add..2576d8ffe 100644 --- a/tests/test_optimizer/test_hybrid_adam.py +++ b/tests/test_optimizer/test_hybrid_adam.py @@ -1,12 +1,12 @@ import torch import torch.nn as nn -from torch.optim.adam import Adam from torch.optim import AdamW +from torch.optim.adam import Adam from colossalai.nn.optimizer.hybrid_adam import HybridAdam from colossalai.testing import parameterize -RE = 1024 +RE = 3 @parameterize('adamw', [False, True]) diff --git a/tests/test_tensor/model/test_gpt2.py b/tests/test_tensor/model/test_gpt2.py index ad8ac87b2..0d6a3fe26 100644 --- a/tests/test_tensor/model/test_gpt2.py +++ b/tests/test_tensor/model/test_gpt2.py @@ -11,7 +11,7 @@ from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, Comput from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from 
tests.test_tensor.common_utils import ( debug_print, diff --git a/tests/test_tensor/model/test_model.py b/tests/test_tensor/model/test_model.py index 3f53b94e0..83abc641c 100644 --- a/tests/test_tensor/model/test_model.py +++ b/tests/test_tensor/model/test_model.py @@ -11,7 +11,7 @@ from colossalai.tensor.colo_parameter import ColoParameter from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import ( check_equal, diff --git a/tests/test_tensor/model/test_module_spec.py b/tests/test_tensor/model/test_module_spec.py index 997b416f1..739bf2b0a 100644 --- a/tests/test_tensor/model/test_module_spec.py +++ b/tests/test_tensor/model/test_module_spec.py @@ -20,7 +20,7 @@ from colossalai.tensor import ( from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py index 2f7aebed5..047371f45 100644 --- a/tests/test_tensor/test_context.py +++ b/tests/test_tensor/test_context.py @@ -17,7 +17,7 @@ from colossalai.tensor import ( from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed diff --git a/tests/test_tensor/test_tp_with_zero.py b/tests/test_tensor/test_tp_with_zero.py index 1a6d23f6a..94e39e5d1 100644 --- a/tests/test_tensor/test_tp_with_zero.py +++ b/tests/test_tensor/test_tp_with_zero.py @@ -7,14 +7,12 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.amp import convert_to_apex_amp -from colossalai.gemini.chunk import search_chunk_configuration -from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer -from colossalai.nn.parallel import GeminiDDP, ZeroDDP from colossalai.tensor import ColoTensor, ColoTensorSpec, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, GeminiDDP, ZeroDDP +from colossalai.zero.gemini import search_chunk_configuration from tests.components_to_test.registry import non_distributed_component_funcs from tests.test_tensor.common_utils import set_seed, tensor_shard_equal from tests.test_tensor.model.test_gpt2 import init_megatron_spec diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py index a5ea75fff..7c2ad9078 100644 --- 
a/tests/test_utils/test_colo_checkpoint.py +++ b/tests/test_utils/test_colo_checkpoint.py @@ -1,25 +1,23 @@ -import os, shutil -import torch -import pytest +import os +import shutil from copy import deepcopy from functools import partial -import torch.multiprocessing as mp +import pytest +import torch import torch.distributed as dist - -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import MultiplicativeLR -from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR +import torch.multiprocessing as mp +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR import colossalai -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.tensor import ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ProcessGroup -from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.optimizer import ColossalaiOptimizer - +from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint +from colossalai.utils.cuda import get_current_device +from colossalai.zero import ColoInitContext from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index 0ecb7446c..6bfa6f33c 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -1,13 +1,12 @@ -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.zero.sharded_param import ShardedTensor -from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline -import colossalai - import torch - import torch.multiprocessing as mp +import colossalai +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.zero.legacy.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline +from colossalai.zero.legacy.sharded_param import ShardedTensor + def run_tensor_move(rank): colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') diff --git a/tests/test_utils/test_zero_gradient_clippling.py b/tests/test_utils/test_zero_gradient_clippling.py index 8bdae8846..920656726 100644 --- a/tests/test_utils/test_zero_gradient_clippling.py +++ b/tests/test_utils/test_zero_gradient_clippling.py @@ -2,21 +2,22 @@ # -*- encoding: utf-8 -*- import copy +from functools import partial -import colossalai -from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2 import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp import torch.nn as nn -from colossalai.logging import disable_existing_loggers -from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.utils import clip_grad_norm_ -from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy -from functools import partial + +import colossalai +from colossalai.logging import 
disable_existing_loggers from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import checkpoint, clip_grad_norm_fp32, free_port +from colossalai.zero.legacy.shard_utils.tensor_shard_strategy import TensorShardStrategy +from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 def checkpoint_wrapper(module, enable=True): diff --git a/tests/test_zero/common.py b/tests/test_zero/common.py index bc6cd75a6..2c3d122c7 100644 --- a/tests/test_zero/common.py +++ b/tests/test_zero/common.py @@ -2,10 +2,11 @@ from functools import partial import torch import torch.distributed as dist + from colossalai.logging import get_dist_logger from colossalai.utils import checkpoint -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 LOGGER = get_dist_logger('zero_test') diff --git a/tests/test_zero/low_level_zero/test_zero_init.py b/tests/test_zero/low_level_zero/test_zero_init.py index 1305da5df..803d0021d 100644 --- a/tests/test_zero/low_level_zero/test_zero_init.py +++ b/tests/test_zero/low_level_zero/test_zero_init.py @@ -9,8 +9,7 @@ import torch.nn as nn import colossalai from colossalai.tensor import ProcessGroup from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer class MlpModel(nn.Module): diff --git a/tests/test_zero/low_level_zero/test_zero_tp.py b/tests/test_zero/low_level_zero/test_zero_tp.py index 15d3530ff..bb7495583 100644 --- a/tests/test_zero/low_level_zero/test_zero_tp.py +++ b/tests/test_zero/low_level_zero/test_zero_tp.py @@ -11,8 +11,7 @@ import colossalai from colossalai.tensor import ProcessGroup from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port, get_current_device -from colossalai.utils.model.colo_init_context import ColoInitContext -from colossalai.zero import LowLevelZeroOptimizer +from colossalai.zero import ColoInitContext, LowLevelZeroOptimizer from tests.test_tensor.common_utils import set_seed, split_param_col_tp1d, split_param_row_tp1d, tensor_shard_equal diff --git a/tests/test_zero/test_found_inf.py b/tests/test_zero/test_found_inf.py index 34283f501..641136718 100644 --- a/tests/test_zero/test_found_inf.py +++ b/tests/test_zero/test_found_inf.py @@ -1,72 +1,72 @@ -from functools import partial - -import colossalai -from colossalai.utils.cuda import get_current_device -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use -from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.sharded_optim._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_zero.test_sharded_optim_v2 import _run_step - -from common import CONFIG - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", 
[BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - run_func = partial(_run_dist, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from common import CONFIG + +import colossalai +from colossalai.nn.optimizer import HybridAdam +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_zero.test_sharded_optim_v2 import _run_step + + +@parameterize("cpu_offload", [True, False]) +@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) +@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) +def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): + test_models = ['repeated_computed_layers'] + shard_strategy = shard_strategy_class() + + for model_name in test_models: + get_components_func = non_distributed_component_funcs.get_callable(model_name) + model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() + + with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), + shard_strategy=shard_strategy, + shard_param=True): + zero_model = model_builder(checkpoint=True) + zero_model = ShardedModelV2( + zero_model, + shard_strategy, + tensor_placement_policy='cpu' if cpu_offload 
else 'cuda', + reuse_fp16_shard=True, + ) + + sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) + sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) + + for i, (data, label) in enumerate(train_dataloader): + if i > 1: + break + assert zero_model.overflow_counter == 0 + data, label = data.cuda(), label.cuda() + _run_step(zero_model, sharded_optim, data, label, criterion, False) + for param in zero_model.parameters(): + assert not has_inf_or_nan(param.colo_attr.data_payload) + + +def _run_dist(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + _run_test_found_inf() + + +# use_cpuadam = True can be used with cpu_offload = False +@pytest.mark.dist +@pytest.mark.parametrize("world_size", [1, 2]) +@rerun_if_address_is_in_use() +def test_found_inf(world_size): + run_func = partial(_run_dist, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_found_inf(world_size=2) diff --git a/tests/test_zero/test_init_context.py b/tests/test_zero/test_init_context.py index 0cba7a492..0eb8842de 100644 --- a/tests/test_zero/test_init_context.py +++ b/tests/test_zero/test_init_context.py @@ -9,14 +9,14 @@ import torch.multiprocessing as mp from common import CONFIG import colossalai -from colossalai.gemini.memory_tracer.utils import colo_model_mem_usage from colossalai.logging import get_dist_logger from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device from colossalai.utils.memory import colo_device_memory_used -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_shard_model_v2.py b/tests/test_zero/test_shard_model_v2.py index 95a9dee38..884444adf 100644 --- a/tests/test_zero/test_shard_model_v2.py +++ b/tests/test_zero/test_shard_model_v2.py @@ -12,11 +12,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_shard_param.py b/tests/test_zero/test_shard_param.py index 8db2b7e79..6085de3c8 100644 --- a/tests/test_zero/test_shard_param.py +++ 
b/tests/test_zero/test_shard_param.py @@ -1,17 +1,18 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_param import ShardedTensor -from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_param import ShardedTensor +from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 from tests.test_zero.common import CONFIG, allclose -from colossalai.gemini.stateful_tensor import StatefulTensor @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) diff --git a/tests/test_zero/test_sharded_optim_state_dict.py b/tests/test_zero/test_sharded_optim_state_dict.py index f8c42930b..d257a0285 100644 --- a/tests/test_zero/test_sharded_optim_state_dict.py +++ b/tests/test_zero/test_sharded_optim_state_dict.py @@ -1,20 +1,21 @@ +from functools import partial + import pytest -import colossalai import torch import torch.multiprocessing as mp -from colossalai.testing import rerun_if_address_is_in_use -from colossalai.utils.cuda import get_current_device -from colossalai.utils import free_port -from functools import partial -from tests.test_tensor.common_utils import set_seed -from tests.components_to_test.registry import non_distributed_component_funcs -from colossalai.testing import parameterize + +import colossalai from colossalai.nn.optimizer import HybridAdam -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_optim import ShardedOptimizerV2 from colossalai.tensor import ProcessGroup +from colossalai.testing import parameterize, rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from tests.components_to_test.registry import non_distributed_component_funcs +from tests.test_tensor.common_utils import set_seed def init_zero(model_builder, placement_policy): diff --git a/tests/test_zero/test_sharded_optim_v2.py b/tests/test_zero/test_sharded_optim_v2.py index 8fe7eb639..3eea13d5d 100644 --- a/tests/test_zero/test_sharded_optim_v2.py +++ b/tests/test_zero/test_sharded_optim_v2.py @@ -13,12 +13,12 @@ from colossalai.nn.optimizer import CPUAdam from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.utils.cuda import get_current_device -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim import ShardedOptimizerV2 -from 
colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs diff --git a/tests/test_zero/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_sharded_optim_with_sync_bn.py index ea5b31518..05512f59a 100644 --- a/tests/test_zero/test_sharded_optim_with_sync_bn.py +++ b/tests/test_zero/test_sharded_optim_with_sync_bn.py @@ -3,18 +3,19 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from torchvision.models import resnet50 + +import colossalai from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import TensorShardStrategy -from torchvision.models import resnet50 +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import TensorShardStrategy def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_state_dict.py b/tests/test_zero/test_state_dict.py index 7ac9b151e..c435d9bb1 100644 --- a/tests/test_zero/test_state_dict.py +++ b/tests/test_zero/test_state_dict.py @@ -4,20 +4,20 @@ from copy import deepcopy from functools import partial -import colossalai import pytest import torch import torch.multiprocessing as mp +from common import CONFIG + +import colossalai from colossalai.testing import parameterize, rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy) -from colossalai.zero.sharded_model import ShardedModelV2 -from colossalai.zero.sharded_model.utils import col_model_deepcopy +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy +from colossalai.zero.legacy.sharded_model import ShardedModelV2 +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy from tests.components_to_test.registry import non_distributed_component_funcs -from common import CONFIG - @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) def run_zero_state_dict(shard_strategy_class): diff --git a/tests/test_zero/test_tensor_utils.py b/tests/test_zero/test_tensor_utils.py index 81855ff5e..311448170 100644 --- a/tests/test_zero/test_tensor_utils.py +++ b/tests/test_zero/test_tensor_utils.py @@ -1,18 +1,21 @@ +from functools import partial + import pytest +import torch +import torch.multiprocessing as mp import colossalai -from colossalai.utils.cuda import get_current_device -from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu, - colo_model_tensor_clone) -from colossalai.gemini.stateful_tensor import StatefulTensor 
-from colossalai.utils import free_port from colossalai.testing import rerun_if_address_is_in_use - -import torch - -from functools import partial -import torch.multiprocessing as mp +from colossalai.utils import free_port +from colossalai.utils.cuda import get_current_device +from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor +from colossalai.zero.legacy.gemini.tensor_utils import ( + colo_model_data_move_to_cpu, + colo_model_data_tensor_move, + colo_model_data_tensor_move_inline, + colo_model_tensor_clone, + colo_tensor_mem_usage, +) def _run_colo_tensor_mem_usage(): diff --git a/tests/test_zero/test_zero_engine.py b/tests/test_zero/test_zero_engine.py index 80ded65d6..1e7f53358 100644 --- a/tests/test_zero/test_zero_engine.py +++ b/tests/test_zero/test_zero_engine.py @@ -3,21 +3,21 @@ from functools import partial -import colossalai import pytest import torch import torch.distributed as dist import torch.multiprocessing as mp +from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params +from torch.nn.parallel import DistributedDataParallel as DDP + +import colossalai from colossalai.core import global_context as gpc from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port -from colossalai.zero.init_ctx import ZeroInitContext -from colossalai.zero.sharded_model.utils import col_model_deepcopy -from colossalai.zero.sharded_optim._utils import has_inf_or_nan +from colossalai.zero.legacy.init_ctx import ZeroInitContext +from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy +from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -from torch.nn.parallel import DistributedDataParallel as DDP - -from common import (MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params) def run_dist(rank, world_size, port, parallel_config):
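Note: the hunks above are purely mechanical; the legacy ZeRO classes keep their behaviour but move under colossalai.zero.legacy, while the init context and low-level utilities are now exposed from colossalai.zero and colossalai.zero.low_level. As a hedged illustration only (not part of this diff), the sketch below mirrors the wiring already exercised in tests/test_zero/test_found_inf.py using the relocated import paths; the single-process launch parameters and the toy nn.Linear model are illustrative assumptions, not code from the repository.

import torch.nn as nn

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import free_port
from colossalai.utils.cuda import get_current_device
from colossalai.zero.legacy.init_ctx import ZeroInitContext
from colossalai.zero.legacy.shard_utils import TensorShardStrategy
from colossalai.zero.legacy.sharded_model import ShardedModelV2
from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2


def build_sharded_model_and_optim():
    # Single-process launch, following the pattern used in tests/test_utils/test_commons.py.
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    shard_strategy = TensorShardStrategy()
    # Parameters created inside ZeroInitContext are sharded at construction time.
    with ZeroInitContext(target_device=get_current_device(),
                         shard_strategy=shard_strategy,
                         shard_param=True):
        model = nn.Linear(16, 16)    # toy model, standing in for the test registry builders

    # Wrap with the relocated legacy ZeRO model/optimizer wrappers.
    zero_model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='cuda')
    zero_optim = ShardedOptimizerV2(zero_model, HybridAdam(zero_model.parameters(), lr=1e-3),
                                    gpu_margin_mem_ratio=0.0)
    return zero_model, zero_optim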