diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py deleted file mode 100644 index fe0308668..000000000 --- a/colossalai/utils/checkpoint_io/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .io import load, merge, redist, save -from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py deleted file mode 100644 index 140192c05..000000000 --- a/colossalai/utils/checkpoint_io/backend.py +++ /dev/null @@ -1,74 +0,0 @@ -import shutil -import tempfile -from abc import ABC, abstractmethod -from typing import Dict, List, Type - -from .reader import CheckpointReader, DiskCheckpointReader -from .writer import CheckpointWriter, DiskCheckpointWriter - -_backends: Dict[str, Type['CheckpointIOBackend']] = {} - - -def register(name: str): - assert name not in _backends, f'"{name}" is registered' - - def wrapper(cls): - _backends[name] = cls - return cls - - return wrapper - - -def get_backend(name: str) -> 'CheckpointIOBackend': - assert name in _backends, f'Unsupported backend "{name}"' - return _backends[name]() - - -class CheckpointIOBackend(ABC): - - def __init__(self) -> None: - super().__init__() - self.temps: List[str] = [] - - @abstractmethod - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - pass - - @abstractmethod - def get_reader(self, base_name: str) -> CheckpointReader: - pass - - @abstractmethod - def get_temp(self, base_name: str) -> str: - pass - - @abstractmethod - def clean_temp(self) -> None: - pass - - -@register('disk') -class CheckpointDiskIO(CheckpointIOBackend): - - def get_writer(self, - base_name: str, - overwrite: bool = False, - rank: int = 0, - world_size: int = 1) -> CheckpointWriter: - return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) - - def get_reader(self, base_name: str) -> CheckpointReader: - return DiskCheckpointReader(base_name) - - def get_temp(self, base_name: str) -> str: - temp_dir_name = tempfile.mkdtemp(dir=base_name) - self.temps.append(temp_dir_name) - return temp_dir_name - - def clean_temp(self) -> None: - for temp_dir_name in self.temps: - shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py deleted file mode 100644 index 219948474..000000000 --- a/colossalai/utils/checkpoint_io/constant.py +++ /dev/null @@ -1,9 +0,0 @@ -import re - -GLOBAL_META_FILE_NAME = 'global_meta.bin' -MODEL_CKPT_FILE_NAME = 'model.bin' -OPTIM_CKPT_FILE_NAME = 'optim.bin' -META_CKPT_FILE_NAME = 'meta.bin' -OTHER_CKPT_FILE_NAME = 'other.bin' - -CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py deleted file mode 100644 index 529ceb868..000000000 --- a/colossalai/utils/checkpoint_io/convertor.py +++ /dev/null @@ -1,227 +0,0 @@ -from abc import ABC, abstractmethod -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional - -from torch import Tensor - -from .distributed import merge_param, unmerge_param -from .meta import ParamDistMeta, RedistMeta -from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) - - -class CheckpointConvertor(ABC): - - @abstractmethod - def append(self, shard_dict: Dict[int, dict], dist_meta_list: 
List[Optional[Dict[str, ParamDistMeta]]]) -> None: - pass - - @abstractmethod - def complete(self) -> None: - pass - - -class ModelCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int]) -> None: - super().__init__() - self.param_count = param_count - self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) - - @abstractmethod - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - for k, tensor in state_dict.items(): - self.buffer[k][rank] = tensor - converted_keys = set() - for k, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[k]: - tensors = [] - dist_metas = [] - for rank, tensor in rank_dict.items(): - tensors.append(tensor) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][k]) - self.convert_tensors(k, tensors, dist_metas) - converted_keys.add(k) - for k in converted_keys: - del self.buffer[k] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class ModelCheckpointMerger(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: - super().__init__(param_count) - self.sharder = ModelCheckpointSharder(max_shard_size) - self.save_fn = save_fn - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - shard = self.sharder.append(key, tensor) - run_if_not_none(self.save_fn, shard) - - def complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class ModelCheckpointRedistor(ModelCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - redist_meta: RedistMeta) -> None: - super().__init__(param_count) - self.save_fns = save_fns - self.redist_meta = redist_meta - nprocs = len(save_fns) - self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: - if len(dist_metas) == 0: - # already global - tensor = tensors[0] - else: - assert len(dist_metas) == len(tensors) - tensor = merge_param(tensors, dist_metas) - for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[key][tp_rank][dp_rank]: - shard = self.sharders[rank].append(key, t) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) - - -class OptimizerCheckpointConvertor(CheckpointConvertor): - - def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], - paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__() - self.param_count = param_count - self.param_to_os = param_to_os - self.paired_os = paired_os - 
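# A minimal sketch of how these convertors are driven (illustrative only: the shard
# contents, dist metas and param_count below are made up). Shard dicts are appended
# rank by rank; whenever the internal sharder fills up, the flushed shard is handed to
# save_fn, and complete() flushes whatever remains.
import torch
from colossalai.utils.checkpoint_io.convertor import ModelCheckpointMerger
from colossalai.utils.checkpoint_io.meta import ParamDistMeta

saved_shards = []
merger = ModelCheckpointMerger(max_shard_size=0, save_fn=saved_shards.append, param_count={'fc.weight': 2})
dist_meta_list = [{'fc.weight': ParamDistMeta(0, 1, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2])}
                  for tp_rank in range(2)]
merger.append({rank: {'fc.weight': torch.rand(4, 2)} for rank in range(2)}, dist_meta_list)
merger.complete()    # saved_shards now holds one shard containing the merged (4, 4) 'fc.weight'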
self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) - self.os_to_param = {v: k for k, v in param_to_os.items()} - - @abstractmethod - def setup(self, param_groups: dict) -> None: - pass - - @abstractmethod - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - pass - - def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for rank, state_dict in shard_dict.items(): - self.setup(state_dict['param_groups']) - for idx, state in state_dict['state'].items(): - self.buffer[idx][rank] = state - converted_indices = set() - for idx, rank_dict in self.buffer.items(): - if len(rank_dict) == self.param_count[self.os_to_param[idx]]: - states = [] - dist_metas = [] - for rank, state in rank_dict.items(): - states.append(state) - if dist_meta_list[rank] is not None: - dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) - self.convert_states(idx, states, dist_metas) - converted_indices.add(idx) - for idx in converted_indices: - del self.buffer[idx] - - def complete(self) -> None: - assert len(self.buffer) == 0 - - -class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fn = save_fn - self.sharder = None - - def setup(self, param_groups: dict) -> None: - if self.sharder is None: - self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) == len(states) - new_state = {} - for state_key, state_tensor in states[0].items(): - if self.paired_os[idx][state_key]: - new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) - else: - new_state[state_key] = state_tensor - shard = self.sharder.append(idx, new_state) - run_if_not_none(self.save_fn, shard) - - def complete(self) -> None: - super().complete() - run_if_not_none(self.save_fn, self.sharder.complete()) - - -class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): - - def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], - param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], - redist_meta: RedistMeta) -> None: - super().__init__(param_count, param_to_os, paired_os) - self.max_shard_size = max_shard_size - self.save_fns = save_fns - self.redist_meta = redist_meta - self.sharders: List[OptimizerCheckpointSharder] = [] - self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) - for k, rank_meta in redist_meta.rank_meta.items(): - for rank, rank_info in rank_meta.items(): - self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) - - def setup(self, param_groups: dict) -> None: - if len(self.sharders) == 0: - nprocs = len(self.save_fns) - for _ in range(nprocs): - self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) - - def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: - need_merge: bool = True - if len(dist_metas) == 0: - need_merge = False - else: - assert len(dist_metas) == len(states) - new_states = [{} for _ in range(len(self.save_fns))] - for state_key, state_tensor in 
states[0].items(): - if self.paired_os[idx][state_key]: - if need_merge: - tensor = merge_param([state[state_key] for state in states], dist_metas) - else: - tensor = state_tensor - for tp_rank, tensor_list in enumerate( - unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): - for dp_rank, t in enumerate(tensor_list): - for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: - new_states[rank][state_key] = t - else: - for new_state in new_states: - new_state[state_key] = state_tensor - for rank, new_state in enumerate(new_states): - shard = self.sharders[rank].append(idx, new_state) - run_if_not_none(self.save_fns[rank], shard) - - def complete(self) -> None: - super().complete() - for rank, save_fn in enumerate(self.save_fns): - run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py deleted file mode 100644 index bf720437c..000000000 --- a/colossalai/utils/checkpoint_io/distributed.py +++ /dev/null @@ -1,127 +0,0 @@ -import torch -from numpy import prod -from torch import Tensor -from typing import List, Optional, Tuple -from collections import defaultdict -from .meta import ParamDistMeta, ParamRedistMeta - - -def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.' - if not dist_metas[0].used_zero: - # tensors are replicate - return tensors[0] - numel = dist_metas[0].zero_numel - orig_shape = dist_metas[0].zero_orig_shape - tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)] - assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.' - return torch.cat(tensors).reshape(orig_shape) - - -def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - for dist_meta in dist_metas[1:]: - assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.' - for t in tensors[1:]: - assert t.shape == tensors[0].shape, 'Expect all params have the same shape.' - if not dist_metas[0].used_tp: - # tensors are replicate - return tensors[0] - total_parts = prod(dist_meta.tp_num_parts) - assert dist_meta.tp_world_size == total_parts, \ - f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.' 
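# A worked example of the gather loop below (illustrative shapes). With tp_shard_dims=[0, 1]
# and tp_num_parts=[2, 3], the six shards are expected in row-major order over a 2x3 grid;
# the loop concatenates along the higher shard dim first (dim 1), then along dim 0.
import torch
from colossalai.utils.checkpoint_io.distributed import gather_tp_param
from colossalai.utils.checkpoint_io.meta import ParamDistMeta

full = torch.rand(4, 6)
shards = [blk.contiguous() for row in full.chunk(2, 0) for blk in row.chunk(3, 1)]
metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
assert torch.equal(gather_tp_param(shards, metas), full)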
- shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True) - for dim, num_parts in shard_info: - buffer = [] - for start in range(0, len(tensors), num_parts): - buffer.append(torch.cat(tensors[start:start + num_parts], dim)) - tensors = buffer - assert len(tensors) == 1 - return tensors[0] - - -def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None: - assert len(dist_metas) > 0 - # check world size - for dist_meta in dist_metas[1:]: - assert dist_meta.dp_world_size == dist_metas[ - 0].dp_world_size, 'Expect all dist meta have the same dp_world_size' - assert dist_meta.tp_world_size == dist_metas[ - 0].tp_world_size, 'Expect all dist meta have the same tp_world_size' - - -def deduplicate_params(tensors: List[Tensor], - dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]: - unique_dist_meta = [] - unique_idx = [] - for i, dist_meta in enumerate(dist_metas): - if dist_meta not in unique_dist_meta: - unique_dist_meta.append(dist_meta) - unique_idx.append(i) - return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx] - - -def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor: - assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas) - # validate parallel info - validate_parallel_info(dist_metas) - tensors, dist_metas = deduplicate_params(tensors, dist_metas) - unflattened_tensors = [] - # group zero params by tp rank - tensor_dict = defaultdict(list) - dist_meta_dict = defaultdict(list) - for t, dist_meta in zip(tensors, dist_metas): - tensor_dict[dist_meta.tp_rank].append(t) - dist_meta_dict[dist_meta.tp_rank].append(dist_meta) - assert len(tensor_dict - ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}' - for tp_rank in tensor_dict.keys(): - unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank])) - return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()]) - - -def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_tp: - assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' - return [tensor] - total_parts = prod(redist_meta.tp_num_parts) - assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' - shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) - tensors = [tensor] - for dim, num_parts in shard_info: - buffer = [] - for t in tensors: - assert t.size(dim) % num_parts == 0, \ - f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' 
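# A sketch of the redistribution direction, which unmerge_param (defined below) builds
# from split_tp_param and flatten_zero_param; shapes and offsets here are illustrative.
# The result is indexed as shards[tp_rank][dp_rank].
import torch
from colossalai.utils.checkpoint_io.distributed import unmerge_param
from colossalai.utils.checkpoint_io.meta import ParamRedistMeta

full = torch.rand(4, 6)
meta = ParamRedistMeta(dp_world_size=2, tp_world_size=2, tp_shard_dims=[1], tp_num_parts=[2],
                       zero_start_dp_rank=0, zero_offsets=[0, 5])
shards = unmerge_param(full, meta)
assert len(shards) == 2 and len(shards[0]) == 2
assert shards[0][0].numel() == 5 and shards[0][1].numel() == 7    # each (4, 3) tp shard is split 5 + 7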
- chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] - buffer.extend(chunks) - tensors = buffer - assert len(tensors) == redist_meta.tp_world_size - return tensors - - -def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: - if not redist_meta.used_zero: - return [tensor] * redist_meta.dp_world_size - tensors: List[Optional[Tensor]] = [ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) - ] - offsets = redist_meta.zero_offsets + [tensor.numel()] - for i, offset in enumerate(offsets[:-1]): - end = offsets[i + 1] - tensors.append(tensor.view(-1)[offset:end]) - if len(tensors) < redist_meta.dp_world_size: - tensors.extend([ - torch.empty(0, dtype=tensor.dtype, device=tensor.device) - for _ in range(redist_meta.dp_world_size - len(tensors)) - ]) - assert len(tensors) == redist_meta.dp_world_size - return tensors - - -def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: - tensors = split_tp_param(tensor, redist_meta) - tensors = [flatten_zero_param(t, redist_meta) for t in tensors] - return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py deleted file mode 100644 index f00212cdf..000000000 --- a/colossalai/utils/checkpoint_io/io.py +++ /dev/null @@ -1,170 +0,0 @@ -import warnings -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple - -import torch.distributed as dist -from torch.nn import Module -from torch.optim import Optimizer - -from .backend import get_backend -from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, - OptimizerCheckpointRedistor) -from .meta import ParamDistMeta, RedistMeta -from .utils import build_checkpoints, optimizer_load_state_dict - - -def save(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk', - **kwargs: Any) -> None: - io_backend = get_backend(backend) - if dist.is_initialized(): - rank = dist.get_rank() - world_size = dist.get_world_size() - else: - rank = 0 - world_size = 1 - if world_size == 1: - # global doesn't need dist_meta - dist_meta = None - else: - assert dist_meta is not None - max_shard_size = int(max_shard_size_gb * 1024**3) - model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, - param_to_os, dist_meta) - writer = io_backend.get_writer(path, overwrite, rank, world_size) - writer.save_others(kwargs) - for model_checkpoint in model_checkpoints: - writer.save_model(model_checkpoint) - for optimizer_checkpoint in optimizer_checkpoints: - writer.save_optimizer(optimizer_checkpoint) - writer.save_meta(meta_checkpoint) - - -def merge(path: str, - output_path: str, - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - reader = io_backend.get_reader(path) - if len(reader.meta_list) == 1: - # already global - warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') - return False - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - writer = io_backend.get_writer(output_path, overwrite=overwrite) - writer.save_others(reader.load_others()) - 
max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), - dist_meta_list) - _convert_shards( - OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), - reader.load_optimizers(), dist_meta_list) - meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def redist(path: str, - output_path: str, - redist_meta: RedistMeta, - dist_metas: List[Dict[str, ParamDistMeta]], - max_shard_size_gb: float = 0.0, - overwrite: bool = False, - backend: str = 'disk') -> bool: - io_backend = get_backend(backend) - if dist.is_initialized() and dist.get_rank() != 0: - return False - nprocs = len(dist_metas) - reader = io_backend.get_reader(path) - dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() - do_redist: bool = False - if len(dist_meta_list) == nprocs: - for a, b in zip(dist_metas, dist_meta_list): - if a != b: - do_redist = True - break - else: - do_redist = True - if not do_redist: - warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') - return False - - writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] - writers[0].save_others(reader.load_others()) - max_shard_size = int(max_shard_size_gb * 1024**3) - _convert_shards( - ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), - reader.load_models(), dist_meta_list) - _convert_shards( - OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, - param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) - for writer, dist_meta in zip(writers, dist_metas): - meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} - if param_to_os is not None: - meta_checkpoint['param_to_os'] = param_to_os - meta_checkpoint['paired_os'] = paired_os - writer.save_meta(meta_checkpoint) - return True - - -def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], - dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: - for shard_dict in shard_generator: - convertor.append(shard_dict, dist_meta_list) - convertor.complete() - - -def load(path: str, - model: Module, - optimizer: Optional[Optimizer] = None, - redist_meta: Optional[RedistMeta] = None, - dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, - max_shard_size_gb: float = 0.0, - backend: str = 'disk') -> dict: - is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 - rank: int = dist.get_rank() if dist.is_initialized() else 0 - is_main_process: bool = rank == 0 - # validate args - if redist_meta is None or dist_metas is None: - assert is_global - io_backend = get_backend(backend) - read_path: str = path - if is_main_process: - # pre-process checkpoints - temp_path = io_backend.get_temp(path) - if is_global: - wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) - else: - wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) - if wrote: - read_path = temp_path - if not is_global: - bcast_list = [read_path] if is_main_process else [None] - dist.broadcast_object_list(bcast_list) - read_path = bcast_list[0] - 
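# A minimal end-to-end sketch of the save/load API in this file for the single-process
# case (the model, optimizer and directory below are placeholders). Without a process
# group no dist_meta is needed: save() writes a global checkpoint and load() restores
# it in place, returning whatever extra kwargs were stored alongside it.
import torch.nn as nn
from torch.optim import Adam
from colossalai.utils.checkpoint_io.io import load, save

model = nn.Linear(20, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
save('ckpt_dir', model, optimizer, epoch=3)    # 'epoch' ends up in other.bin
others = load('ckpt_dir', model, optimizer)
assert others['epoch'] == 3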
reader = io_backend.get_reader(read_path) - # load model - for shard in reader.load_model(rank): - model.load_state_dict(shard, strict=False) - if optimizer is not None: - for shard in reader.load_optimizer(rank): - # optimizer.load_state_dict(shard) - optimizer_load_state_dict(optimizer, shard) - others_dict = reader.load_others() - if not is_global: - dist.barrier() - # clean up temp - if is_main_process: - io_backend.clean_temp() - return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py deleted file mode 100644 index 994f08b4b..000000000 --- a/colossalai/utils/checkpoint_io/meta.py +++ /dev/null @@ -1,81 +0,0 @@ -from dataclasses import dataclass -from typing import List, Optional, Set, Dict - - -@dataclass -class ParamDistMeta: - # parallel info - dp_rank: int - dp_world_size: int - tp_rank: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_numel: Optional[int] = None - zero_orig_shape: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_numel is not None and self.zero_orig_shape is not None - - @property - def parallel_meta(self) -> tuple: - return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size - - @property - def tp_meta(self) -> tuple: - return self.tp_shard_dims, self.tp_num_parts - - @property - def zero_meta(self) -> tuple: - return self.zero_numel, self.zero_orig_shape - - @staticmethod - def from_dict(d: dict) -> 'ParamDistMeta': - return ParamDistMeta(**d) - - -@dataclass -class ParamRedistMeta: - # parallel info - dp_world_size: int - tp_world_size: int - # tp info - tp_shard_dims: Optional[List[int]] = None - tp_num_parts: Optional[List[int]] = None - # zero info - zero_start_dp_rank: Optional[int] = None - zero_offsets: Optional[List[int]] = None - - @property - def used_tp(self) -> bool: - return self.tp_shard_dims is not None and self.tp_num_parts is not None - - @property - def used_zero(self) -> bool: - return self.zero_start_dp_rank is not None and self.zero_offsets is not None - - -@dataclass -class RankRedistMeta: - dp_rank: int - tp_rank: int - pp_rank: int - - -@dataclass -class PipelineRedistMeta: - params: Set[str] - - -@dataclass -class RedistMeta: - rank_meta: Dict[str, Dict[int, RankRedistMeta]] - pipeline_meta: List[PipelineRedistMeta] - param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py deleted file mode 100644 index 3158c6481..000000000 --- a/colossalai/utils/checkpoint_io/reader.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -from abc import ABC, abstractmethod -from collections import Counter -from typing import Dict, Generator, List, Optional, Tuple - -import torch - -from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME -from .meta import ParamDistMeta -from .utils import is_duplicated_list - - -class CheckpointReader(ABC): - - def __init__(self, base_name: str) -> None: - super().__init__() - self.base_name = base_name - self.meta_list = [] - - @abstractmethod - def read(self, name: str) -> dict: - pass - - @abstractmethod - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - pass - - @abstractmethod - def load_model(self, rank: int) -> Generator[dict, None, None]: 
- pass - - @abstractmethod - def load_models(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - pass - - @abstractmethod - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - pass - - @abstractmethod - def load_others(self) -> dict: - pass - - -class DiskCheckpointReader(CheckpointReader): - - def __init__(self, base_name: str) -> None: - super().__init__(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - global_meta = self.read(GLOBAL_META_FILE_NAME) - for meta_file_name in global_meta['meta']: - meta = self.read(meta_file_name) - if meta.get('dist_meta', None) is None: - # only global checkpoint can have empty dist_meta - assert len(global_meta['meta']) == 1 - self.meta_list.append(meta) - - def read(self, name: str) -> dict: - return torch.load(os.path.join(self.base_name, name)) - - def load_meta( - self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: - meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', - None), meta.get('paired_os', None)) - for meta in self.meta_list] - dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) - # reduce param_count - param_count = Counter(p for params in params_list for p in params) - # validate param_to_os - assert is_duplicated_list(param_to_os_list) - assert is_duplicated_list(paired_os_list) - return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] - - def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: - meta = self.meta_list[rank] - checkpoint_names = meta.get(shard_type, []) - for name in checkpoint_names: - yield self.read(name) - - def load_model(self, rank: int) -> Generator[dict, None, None]: - return self._load_shard('model', rank) - - def load_models(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - model_checkpoint_names = meta.get('model', []) - if indices[i] < len(model_checkpoint_names): - shards[i] = self.read(model_checkpoint_names[indices[i]]) - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_optimizer(self, rank: int) -> Generator[dict, None, None]: - param_groups = None - for shard in self._load_shard('optimizer', rank): - if param_groups is None: - param_groups = shard['param_groups'] - else: - shard['param_groups'] = param_groups - yield shard - - def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: - indices = [0] * len(self.meta_list) - param_groups = [] - while True: - shards = {} - for i, meta in enumerate(self.meta_list): - optimizer_checkpoint_names = meta.get('optimizer', []) - if indices[i] < len(optimizer_checkpoint_names): - shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) - if indices[i] == 0: - param_groups.append(shards[i]['param_groups']) - else: - shards[i]['param_groups'] = param_groups[i] - indices[i] += 1 - if len(shards) > 0: - yield shards - else: - break - - def load_others(self) -> dict: - return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py deleted file mode 100644 index 135385f57..000000000 --- a/colossalai/utils/checkpoint_io/utils.py +++ /dev/null @@ -1,223 +0,0 @@ -import warnings -from copy import deepcopy -from itertools import chain -from 
typing import Any, Callable, Dict, List, Optional, Tuple - -from torch import Tensor -from torch.nn import Module -from torch.nn.parameter import Parameter -from torch.optim import Optimizer - -from .meta import ParamDistMeta - - -def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: - if arg is not None: - return fn(arg) - - -def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: - # ensure all params in optimizer are in model state dict - params_set = set(id(p) for p in model.parameters()) - for group in optimizer.param_groups: - for p in group['params']: - assert id(p) in params_set - param_mappings = {} - start_index = 0 - - def get_group_mapping(group): - nonlocal start_index - param_mappings.update( - {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) - start_index += len(group['params']) - - for g in optimizer.param_groups: - get_group_mapping(g) - return {k: param_mappings[id(p)] for k, p in model.named_parameters()} - - -def compute_optimizer_state_size(state: Dict[str, Any]) -> int: - size = 0 - for v in state.values(): - if isinstance(v, Tensor): - size += v.numel() * v.element_size() - return size - - -class ModelCheckpointSharder: - - def __init__(self, max_shard_size: int) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, Tensor] = {} - self.buffer_size: int = 0 - - def append(self, key: str, tensor: Tensor) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {} - self.buffer_size = 0 - self.buffer[key] = tensor - self.buffer_size += tensor.numel() * tensor.element_size() - return retval - - def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: - shards = [] - for key, tensor in state_dict.items(): - shard = self.append(key, tensor) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer) > 0 else None - - -class OptimizerCheckpointSharder: - - def __init__(self, max_shard_size: int, param_groups: dict) -> None: - self.max_shard_size = max_shard_size - self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} - self.buffer_size: int = 0 - self.returned_first: bool = False - - def append(self, key: int, state: dict) -> Optional[dict]: - retval = None - if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: - retval = self.buffer - self.buffer = {'state': {}} - self.buffer_size = 0 - self.buffer['state'][key] = state - self.buffer_size += compute_optimizer_state_size(state) - return retval - - def extend(self, state_dict: Dict[str, dict]) -> List[dict]: - shards = [] - for key, state in state_dict['state'].items(): - shard = self.append(key, state) - run_if_not_none(shards.append, shard) - return shards - - def complete(self) -> Optional[dict]: - return self.buffer if len(self.buffer['state']) > 0 else None - - -def shard_checkpoint(max_shard_size: int, - model_state_dict: Dict[str, Tensor], - optimizer_state_dict: Optional[dict] = None, - param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: - has_optimizer: bool = False - if optimizer_state_dict is not None: - assert param_to_os is not None - os_to_param = {v: k for k, v in param_to_os.items()} - for os_key in optimizer_state_dict['state'].keys(): - assert os_key in os_to_param - assert os_to_param[os_key] in model_state_dict - has_optimizer = True - model_sharder = 
ModelCheckpointSharder(max_shard_size) - model_shards = model_sharder.extend(model_state_dict) - run_if_not_none(model_shards.append, model_sharder.complete()) - if not has_optimizer: - return model_shards, [] - optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) - optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) - run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) - return model_shards, optimizer_shards - - -def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: - os_to_param = {v: k for k, v in param_to_os.items()} - paired_os = {} - for idx, state in optimizer_state_dict['state'].items(): - paired_os[idx] = {} - p = model_state_dict[os_to_param[idx]] - for k, v in state.items(): - if isinstance(v, Tensor) and v.shape == p.shape: - paired_os[idx][k] = True - else: - paired_os[idx][k] = False - return paired_os - - -def build_checkpoints(max_size: int, - model: Module, - optimizer: Optional[Optimizer] = None, - param_to_os: Optional[Dict[str, int]] = None, - dist_meta: Optional[Dict[str, ParamDistMeta]] = None, - eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: - save_global = dist_meta is None - model_state_dict = model.state_dict() - optimizer_state_dict = optimizer.state_dict() if optimizer else None - meta = {'dist_meta': dist_meta} - if optimizer: - param_to_os = param_to_os or get_param_to_os(model, optimizer) - paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) - meta['param_to_os'] = param_to_os - meta['paired_os'] = paired_os - if not save_global and eliminate_replica: - # filter dp replicated params - model_state_dict = { - k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - if optimizer: - optimizer_state_dict['state'] = { - param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] - for k in model_state_dict.keys() - if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 - } - meta['params'] = list(model_state_dict.keys()) - if len(model_state_dict) == 0: - warnings.warn('model state dict is empty, checkpoint is not saved') - return [], [], meta - model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, - param_to_os) - return model_checkpoints, optimizer_checkpoints, meta - - -def is_duplicated_list(list_: List[Any]) -> bool: - if len(list_) == 0: - return True - elem = list_[0] - for x in list_[1:]: - if x != elem: - return False - return True - - -def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: - for k, v in src_state.items(): - if k in dest_state: - old_v = dest_state[k] - if isinstance(old_v, Tensor): - old_v.copy_(v) - else: - dest_state[k] = v - - -def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None: - assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] - state_dict = deepcopy(state_dict) - groups = optimizer.param_groups - saved_groups = state_dict['param_groups'] - idx_to_p: Dict[str, Parameter] = { - old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups - )), chain.from_iterable((g['params'] for g in groups))) - } - missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) - unexpected_keys = [] - error_msgs = [] - for idx, state in state_dict['state'].items(): - if idx in idx_to_p: - old_state = 
optimizer.state[idx_to_p[idx]] - copy_optimizer_state(state, old_state) - else: - unexpected_keys.append(idx) - if strict: - if len(unexpected_keys) > 0: - error_msgs.insert( - 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) - if len(missing_keys) > 0: - error_msgs.insert( - 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) - if len(error_msgs) > 0: - raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, - "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py deleted file mode 100644 index 4552accde..000000000 --- a/colossalai/utils/checkpoint_io/writer.py +++ /dev/null @@ -1,98 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Optional -from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME -import torch -import os - - -class CheckpointWriter(ABC): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__() - self.base_name = base_name - self.overwrite = overwrite - self.rank = rank - self.world_size = world_size - self.is_distributed = world_size > 1 - self.is_main_process = rank == 0 - - @abstractmethod - def write(self, name: str, state_dict: dict) -> None: - pass - - @abstractmethod - def save_model(self, model_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_meta(self, meta_checkpoint: dict) -> None: - pass - - @abstractmethod - def save_others(self, kwargs: dict) -> None: - pass - - -class DiskCheckpointWriter(CheckpointWriter): - - def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: - super().__init__(base_name, overwrite, rank, world_size) - if not os.path.exists(base_name): - os.makedirs(base_name) - assert os.path.isdir(base_name), f'"{base_name}" is not a directory' - self.model_checkpoint_names = [] - self.optimizer_checkpoint_names = [] - self.is_meta_saved: bool = False - self._save_global_meta() - - def write(self, name: str, state_dict: dict) -> None: - path = os.path.join(self.base_name, name) - if os.path.exists(path) and not self.overwrite: - raise RuntimeError(f'Save error: Checkpoint "{path}" exists. 
(overwrite = False)') - torch.save(state_dict, path) - - def _save_global_meta(self) -> None: - if self.is_main_process: - global_meta = {'meta': []} - if self.is_distributed: - for i in range(self.world_size): - global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) - else: - global_meta['meta'].append(META_CKPT_FILE_NAME) - self.write(GLOBAL_META_FILE_NAME, global_meta) - - def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: - checkpoint_name = base_name - if self.is_distributed: - checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') - if shard_idx is not None: - checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') - return checkpoint_name - - def save_model(self, model_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save model after saving meta' - name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) - self.write(name, model_checkpoint) - self.model_checkpoint_names.append(name) - - def save_optimizer(self, optimizer_checkpoint: dict) -> None: - assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' - name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) - self.write(name, optimizer_checkpoint) - self.optimizer_checkpoint_names.append(name) - - def save_meta(self, meta_checkpoint: dict) -> None: - if len(self.model_checkpoint_names) > 0: - meta_checkpoint['model'] = self.model_checkpoint_names - if len(self.optimizer_checkpoint_names) > 0: - meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names - self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) - self.is_meta_saved = True - - def save_others(self, kwargs: dict) -> None: - if self.is_main_process: - self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py deleted file mode 100644 index 622d9deb6..000000000 --- a/tests/test_lazy/test_distribute.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Optional - -import pytest -import torch -import torch.nn as nn - -import colossalai -from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.common import print_rank_0 - -try: - from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -except: - pass -from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed - -from tests.kit.model_zoo import model_zoo - - -def find_shard_dim(shape: torch.Size) -> Optional[int]: - for dim, size in enumerate(shape): - if size % 2 == 0: - return dim - - -def make_sharding_spec(original_tensor: torch.Tensor) -> Layout: - shard_dim = find_shard_dim(original_tensor.shape) - dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {} - target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict) - return target_sharding_spec - - -def _get_current_name(prefix: str, name: str) -> str: - return f'{prefix}.{name}'.lstrip('.') - - -def generate_sharding_spec_dict(model: nn.Module) -> dict: - sharding_spec_dict = {} - - @torch.no_grad() - def generate_recursively(module: nn.Module, prefix: str = ''): - # recursively initialize the module - for name, mod in module.named_children(): - 
generate_recursively(mod, prefix=_get_current_name(prefix, name)) - - # initialize tensors directly attached to the current module - for name, param in module.named_parameters(recurse=False): - if isinstance(param, LazyTensor): - sharding_spec = make_sharding_spec(param) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec - - for name, buf in module.named_buffers(recurse=False): - if isinstance(buf, LazyTensor): - sharding_spec = make_sharding_spec(buf) - sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec - - generate_recursively(model) - - return sharding_spec_dict - - -@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm']) -def run_dist_lazy_init(subset, seed: int = 42): - sub_model_zoo = model_zoo.get_sub_registry(subset) - device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) - _MyTensor._pre_op_fn = lambda *args: set_seed(seed) - LazyTensor._pre_op_fn = lambda *args: set_seed(seed) - - for name, entry in sub_model_zoo.items(): - # TODO(ver217): lazy init does not support weight norm, skip these models - if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base') or name.startswith('transformers_llama'): - continue - print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry - ctx = LazyInitContext(tensor_cls=_MyTensor) - with ctx: - model = model_fn() - ctx = LazyInitContext() - with ctx: - deferred_model = model_fn() - sharding_spec_dict = generate_sharding_spec_dict(deferred_model) - ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True) - assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict) - - -def run_dist(rank, world_size, port) -> None: - colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port) - run_dist_lazy_init() - - -@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0') -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_dist_lazy_init(): - spawn(run_dist, 4) - - -if __name__ == '__main__': - test_dist_lazy_init() diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py deleted file mode 100644 index 6d89fb90c..000000000 --- a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import torch.nn as nn -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.utils import build_checkpoints -from torch.optim import Adam - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def test_global_model(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - global_state_dict = model_checkpoints[0] - assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) - for k, v in orig_state_dict.items(): - assert torch.equal(v, global_state_dict[k]) - - -def test_global_model_shard(): - model = DummyModel() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) - assert len(model_checkpoints) == 2 - assert len(optimizer_checkpoints) == 0 - assert meta['dist_meta'] is None - orig_state_dict = model.state_dict() - assert set(orig_state_dict.keys()) == 
set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) - assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 - for k, v in orig_state_dict.items(): - for state_dict in model_checkpoints: - if k in state_dict: - assert torch.equal(v, state_dict[k]) - - -def test_global_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer) - assert len(optimizer_checkpoints) == 1 - assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1} - for state in meta['paired_os'].values(): - for k, is_paired in state.items(): - if k == 'step': - assert not is_paired - else: - assert is_paired - orig_state_dict = optimizer.state_dict() - state_dict = optimizer_checkpoints[0] - for k, orig_state in orig_state_dict['state'].items(): - state = state_dict['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v2 == v2 - assert orig_state_dict['param_groups'] == state_dict['param_groups'] - - -def test_global_optimizer_shard(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer) - assert len(optimizer_checkpoints) == 2 - assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1] - orig_state_dict = optimizer.state_dict() - assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set( - optimizer_checkpoints[1]['state'].keys()) - assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0 - for k, orig_state in orig_state_dict['state'].items(): - state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][ - 'state'] else optimizer_checkpoints[1]['state'][k] - for v1, v2 in zip(orig_state.values(), state.values()): - if isinstance(v2, torch.Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - - assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups'] - - -def test_dist_model_optimizer(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0] - assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state'] - dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)} - model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta) - assert dist_meta == meta['dist_meta'] - assert len(model_checkpoints) == 1 - assert len(optimizer_checkpoints) == 1 - - -if __name__ == '__main__': - test_global_model() - test_global_model_shard() - test_global_optimizer() - test_global_optimizer_shard() - test_dist_model_optimizer() diff --git 
a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py deleted file mode 100644 index 2949c9f07..000000000 --- a/tests/test_utils/test_checkpoint_io/test_load.py +++ /dev/null @@ -1,186 +0,0 @@ -from copy import deepcopy -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.nn import Module -from torch.optim import Adam, Optimizer - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.io import load, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_groups: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.rand_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): - with torch.no_grad(): - for p in model.parameters(): - p.fill_(scalar) - for state in optimizer.state.values(): - for v in state.values(): - if isinstance(v, Tensor): - v.fill_(scalar) - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) 
-def test_save_global_load_global(max_shard_size_gb: float): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=max_shard_size_gb) - new_model, new_optimizer = prepare_model_optim() - load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) - check_model_state_dict(model.state_dict(), new_model.state_dict()) - check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def launch_dist(fn, world_size: int): - spawn(run_dist, world_size, test_fn=fn) - - -def save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - reset_model_optim(model, optimizer) - world_size = dist.get_world_size() - rank = dist.get_rank() - save(dir_name, model, optimizer, dist_meta=get_dist_metas(world_size, zero)[rank]) - - -def load_and_check_dist(dir_name: str): - world_size = dist.get_world_size() - model, optimizer = prepare_model_optim(shard=True) - reset_model_optim(model, optimizer) - model_state_dict = deepcopy(model.state_dict()) - optimizer_state_dict = deepcopy(optimizer.state_dict()) - reset_model_optim(model, optimizer, 1) - load(dir_name, model, optimizer, get_redist_meta(world_size), get_dist_metas(world_size)) - check_model_state_dict(model_state_dict, model.state_dict()) - check_optim_state_dict(optimizer_state_dict, optimizer.state_dict()) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_global_load_dist(): - model, optimizer = prepare_model_optim() - reset_model_optim(model, optimizer) - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 4) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist_load_dist(): - with TemporaryDirectory() as dir_name: - # save tp + dp - fn = partial(save_dist, dir_name, False) - launch_dist(fn, 2) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - with TemporaryDirectory() as dir_name: - # save tp + zero - fn = partial(save_dist, dir_name, True) - launch_dist(fn, 4) - # load tp + dp - fn = partial(load_and_check_dist, dir_name) - launch_dist(fn, 2) - launch_dist(fn, 4) - - -if __name__ == '__main__': - test_save_global_load_global(80 / 1024**3) - test_save_global_load_global(0) - test_save_global_load_dist() - test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py deleted file mode 100644 index 07d4597f8..000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge.py +++ /dev/null @@ -1,126 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import merge, save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - 
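# Sharding layout used by prepare_model_optim below: rank % 2 selects the tensor-parallel
# shard and rank // 2 the data-parallel shard. With shard=True the (1, 20) fc.weight is
# chunked along dim 1 into two (1, 10) halves; with zero=True each half is additionally
# flattened and split [3, 7] across the dp ranks, and fc.bias is kept only on dp rank 0
# (other dp ranks hold an empty tensor).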
if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_merge_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 0 - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - dp_world_size = dist.get_world_size() // 2 - if not zero: - dist_metas = { - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - } - else: - dist_metas = { - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - } - save(dir_name, model, optimizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_merge_tp_dp(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with TemporaryDirectory() as output_dir: - merge(dir_name, output_dir) - assert len(os.listdir(output_dir)) == 5 - global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 - meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) - assert meta['dist_meta'] is None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 20 - optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 - - -if __name__ == '__main__': - test_merge_global() - test_merge_tp_dp(False) - test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py deleted file mode 100644 index 5da2ae4fe..000000000 --- a/tests/test_utils/test_checkpoint_io/test_merge_param.py +++ /dev/null @@ -1,101 +0,0 @@ 
-import torch -from colossalai.utils.checkpoint_io.meta import ParamDistMeta -from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param - - -def test_unflatten_zero_param_even() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_unflatten_zero_param_uneven() -> None: - dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)] - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - unflattened_tensor = unflatten_zero_param(tensors, dist_metas) - assert torch.equal(orig_tensor, unflattened_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_row() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_1d_col() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)] - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_gather_tp_param_2d_reverse() -> None: - dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)] - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - gathered_tensor = gather_tp_param(tensors, dist_metas) - assert torch.equal(orig_tensor, gathered_tensor) - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_hybrid() -> None: - dist_metas = [ - ParamDistMeta(i % 2, - 2, - i // 2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_numel=4, - zero_orig_shape=[2, 2]) for i in range(12) - ] - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - merged_tensor = merge_param(tensors, dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -def test_merge_param_dummy() -> None: - dist_metas = [ParamDistMeta(0, 1, 0, 1)] - orig_tensor = 
torch.rand(4, 6) - merged_tensor = merge_param([orig_tensor], dist_metas) - assert torch.equal(orig_tensor, merged_tensor) - - -if __name__ == '__main__': - test_unflatten_zero_param_even() - test_unflatten_zero_param_uneven() - test_gather_tp_param_1d_row() - test_gather_tp_param_1d_col() - test_gather_tp_param_2d() - test_gather_tp_param_2d_reverse() - test_merge_param_hybrid() - test_merge_param_dummy() diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py deleted file mode 100644 index fdc849a5e..000000000 --- a/tests/test_utils/test_checkpoint_io/test_redist.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME -from colossalai.utils.checkpoint_io.io import redist, save -from colossalai.utils.checkpoint_io.meta import ( - ParamDistMeta, - ParamRedistMeta, - PipelineRedistMeta, - RankRedistMeta, - RedistMeta, -) - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(shard: bool = False, zero: bool = False): - model = DummyModel() - if shard: - model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] - if zero: - dp_rank = dist.get_rank() // 2 - model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] - if dp_rank != 0: - model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def get_dist_metas(nprocs: int, zero: bool = False): - dp_world_size = nprocs // 2 - dist_metas = [] - for rank in range(nprocs): - if zero: - dist_metas.append({ - 'fc.weight': - ParamDistMeta(rank // 2, - dp_world_size, - rank % 2, - 2, - tp_shard_dims=[1], - tp_num_parts=[2], - zero_numel=10, - zero_orig_shape=[1, 10]), - 'fc.bias': - ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) - }) - else: - dist_metas.append({ - 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) - }) - return dist_metas - - -def get_redist_meta(nprocs: int): - dp_world_size = nprocs // 2 - rank_meta = { - 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, - 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} - } - param_meta = { - 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), - 'fc.bias': ParamRedistMeta(dp_world_size, 1) - } - return RedistMeta(rank_meta, [], param_meta) - - -def check_checkpoint_shape(dir_name: str): - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - for meta_name in global_meta['meta']: - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta['dist_meta'] is not None - assert len(meta['params']) == 2 - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - assert model_state_dict['fc.weight'].size(1) == 10 - 
optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict - assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10 - assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10 - - -def test_global_to_dist(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - with TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - check_checkpoint_shape(output_dir) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={'parallel': { - 'tensor': { - 'mode': '1d', - 'size': 2 - } - }}, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - test_fn() - - -def run_save_dist(dir_name: str, zero: bool): - model, optimizer = prepare_model_optim(shard=True, zero=zero) - rank = dist.get_rank() - save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank]) - - -@pytest.mark.dist -@pytest.mark.parametrize("zero", [False, True]) -@rerun_if_address_is_in_use() -def test_dist_to_dist(zero: bool): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name, zero) - world_size = 4 - spawn(run_dist, world_size, test_fn=fn) - with TemporaryDirectory() as output_dir: - redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4)) - if not zero: - assert len(os.listdir(output_dir)) == 0 - else: - check_checkpoint_shape(output_dir) - - -if __name__ == '__main__': - test_global_to_dist() - test_dist_to_dist(False) - test_dist_to_dist(True) diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py deleted file mode 100644 index 2abdd95a6..000000000 --- a/tests/test_utils/test_checkpoint_io/test_save.py +++ /dev/null @@ -1,149 +0,0 @@ -import os -from functools import partial -from tempfile import TemporaryDirectory -from typing import Dict - -import pytest -import torch -import torch.distributed as dist -import torch.nn as nn -from torch import Tensor -from torch.optim import Adam - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.checkpoint_io.constant import ( - GLOBAL_META_FILE_NAME, - META_CKPT_FILE_NAME, - MODEL_CKPT_FILE_NAME, - OTHER_CKPT_FILE_NAME, -) -from colossalai.utils.checkpoint_io.io import save -from colossalai.utils.checkpoint_io.meta import ParamDistMeta - - -def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: - assert set(a.keys()) == set(b.keys()) - for k, v in a.items(): - assert torch.equal(v, b[k]) - - -def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None: - assert set(a['state'].keys()) == set(b['state'].keys()) - for k, state in a['state'].items(): - b_state = b['state'][k] - for v1, v2 in zip(state.values(), b_state.values()): - if isinstance(v1, Tensor): - assert torch.equal(v1, v2) - else: - assert v1 == v2 - if not ignore_param_groups: - assert a['param_groups'] == b['param_groups'] - - -class DummyModel(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Linear(20, 1) - - -def prepare_model_optim(): - model = DummyModel() - for p in model.parameters(): - p.grad = torch.ones_like(p) - optimizer = Adam(model.parameters(), lr=1e-3) - optimizer.step() - return model, optimizer - - -def test_overwrite(): - 
model = DummyModel() - with TemporaryDirectory() as dir_name: - with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f: - pass - with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'): - save(dir_name, model) - - -def test_save_global(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer) - assert len(os.listdir(dir_name)) == 5 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 1 - assert len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - check_model_state_dict(model.state_dict(), model_state_dict) - optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0])) - check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict) - other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME)) - assert len(other_state_dict) == 0 - - -def test_save_global_shard(): - model, optimizer = prepare_model_optim() - with TemporaryDirectory() as dir_name: - save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) - assert len(os.listdir(dir_name)) == 7 - meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME)) - assert len(meta['model']) == 2 and len(meta['optimizer']) == 2 - model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']] - assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0 - check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]}) - optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']] - assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0 - assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1] - check_optim_state_dict( - optimizer.state_dict(), { - 'state': { - **optimizer_state_dicts[0]['state'], - **optimizer_state_dicts[1]['state'] - }, - 'param_groups': optimizer_state_dicts[0]['param_groups'] - }) - - -def run_dist(rank, world_size, port, test_fn): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - test_fn() - - -def run_save_dist(dir_name): - model, optimizer = prepare_model_optim() - dist_metas = { - 'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1), - 'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1) - } - save(dir_name, model, optimizer, dist_meta=dist_metas) - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_save_dist(): - with TemporaryDirectory() as dir_name: - fn = partial(run_save_dist, dir_name) - world_size = 2 - spawn(run_dist, world_size, test_fn=fn) - assert len(os.listdir(dir_name)) == 8 - global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME)) - assert len(global_meta['meta']) == 2 - for rank, meta_name in enumerate(global_meta['meta']): - meta = torch.load(os.path.join(dir_name, meta_name)) - assert meta.get('dist_meta', None) is not None - assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 - model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0])) - assert len(model_state_dict) == 2 - optimizer_state_dict = 
torch.load(os.path.join(dir_name, meta['optimizer'][0])) - assert len(optimizer_state_dict['state']) == 2 - assert 'param_groups' in optimizer_state_dict - - -if __name__ == '__main__': - test_overwrite() - test_save_global() - test_save_global_shard() - test_save_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py deleted file mode 100644 index 8b83caa12..000000000 --- a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py +++ /dev/null @@ -1,137 +0,0 @@ -import torch -from colossalai.utils.checkpoint_io.meta import ParamRedistMeta -from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param - - -def test_flatten_zero_param_even() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).chunk(4)) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_flatten_zero_param_uneven() -> None: - redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13]) - orig_tensor = torch.rand(4, 4) - tensors = list(orig_tensor.reshape(-1).split([13, 3])) - flat_tensors = flatten_zero_param(orig_tensor, redist_meta) - assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0 - flat_tensors = flat_tensors[1:-1] - assert len(tensors) == len(flat_tensors) - for t, st in zip(tensors, flat_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 - unmerged_tensors = unmerged_tensors[0] - assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0 - unmerged_tensors = unmerged_tensors[1:-1] - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert torch.equal(t, tl) - - -def test_split_tp_param_1d_row() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_1d_col() -> None: - redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4]) - orig_tensor = torch.rand(4, 4) - tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d() -> None: - redist_meta = 
ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_split_tp_param_2d_reverse() -> None: - redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) - orig_tensor = torch.rand(4, 6) - tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)] - split_tensors = split_tp_param(orig_tensor, redist_meta) - assert len(tensors) == len(split_tensors) - for t, st in zip(tensors, split_tensors): - assert torch.equal(t, st) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(tensors) == len(unmerged_tensors) - for t, tl in zip(tensors, unmerged_tensors): - assert len(tl) == 1 - assert torch.equal(t, tl[0]) - - -def test_unmerge_param_hybrid() -> None: - redist_meta = ParamRedistMeta(2, - 6, - tp_shard_dims=[1, 0], - tp_num_parts=[3, 2], - zero_start_dp_rank=0, - zero_offsets=[0, 1]) - orig_tensor = torch.rand(4, 6) - tensors = [ - chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1) - for chunk in t.contiguous().reshape(-1).split([1, 3]) - ] - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2 - for tp_rank in range(6): - for dp_rank in range(2): - assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank]) - - -def test_unmerge_param_dummy() -> None: - redist_meta = ParamRedistMeta(1, 1) - orig_tensor = torch.rand(4, 6) - unmerged_tensors = unmerge_param(orig_tensor, redist_meta) - assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1 - assert torch.equal(orig_tensor, unmerged_tensors[0][0]) - - -if __name__ == '__main__': - test_flatten_zero_param_even() - test_flatten_zero_param_uneven() - test_split_tp_param_1d_row() - test_split_tp_param_1d_col() - test_split_tp_param_2d() - test_split_tp_param_2d_reverse() - test_unmerge_param_hybrid() - test_unmerge_param_dummy() diff --git a/tests/test_zero/test_legacy/common.py b/tests/test_zero/test_legacy/common.py deleted file mode 100644 index 2c3d122c7..000000000 --- a/tests/test_zero/test_legacy/common.py +++ /dev/null @@ -1,140 +0,0 @@ -from functools import partial - -import torch -import torch.distributed as dist - -from colossalai.logging import get_dist_logger -from colossalai.utils import checkpoint -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 - -LOGGER = get_dist_logger('zero_test') - -MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(size=1), tensor=dict(size=2, mode=None))) - -_ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25, - fp32_reduce_scatter=False, - tensor_placement_policy='cuda', - gradient_predivide_factor=1.0, - shard_strategy=TensorShardStrategy(), - reuse_fp16_shard=False) - -_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5, - min_scale=1, - growth_factor=2, - backoff_factor=0.5, - growth_interval=1000, - hysteresis=2, - max_scale=2**32) - -ZERO_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), - 
zero=dict( - model_config=_ZERO_MODEL_CONFIG, - optimizer_config=_ZERO_OPTIMIZER_CONFIG, - ), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - -CONFIG = dict(fp16=dict(mode=None,), - zero=dict(level=3, - verbose=False, - offload_optimizer_config=dict(device='cpu', pin_memory=True, buffer_count=5, fast_init=False), - offload_param_config=dict(device='cpu', - pin_memory=True, - buffer_count=5, - buffer_size=1e8, - max_in_cpu=1e9)), - parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None))) - - -def run_fwd_bwd(model, data, label, criterion, enable_autocast=False): - model.train() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - loss = loss.float() - if isinstance(model, ShardedModelV2): - model.backward(loss) - else: - loss.backward() - - -def checkpoint_wrapper(module, enable=True): - if enable: - module.forward = partial(checkpoint, module.forward) - return module - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-2, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def check_grads(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_grad = zero_p.grad.clone().to(p.device) - grad = p.grad.float() - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose) - - -def check_params(model, zero_model, loose=False): - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - # assert p.dtype == zero_p.dtype - assert allclose(p.float(), zero_p.float(), loose=loose), f"diff {p.float() - zero_p.float()}" - - -def check_grads_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - # zero_grad = zero_p.grad.clone().to(p.device) - if zero_p.colo_attr.is_replicated: - zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device) - chunks = torch.flatten(p.grad).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - grad = chunks[rank].float() - if zero_grad.size(0) > grad.size(0): - zero_grad = zero_grad[:grad.size(0)] - else: - zero_grad = zero_p.colo_attr.grad_payload - grad = p.grad.to(zero_grad.dtype) - - assert grad.dtype == zero_grad.dtype - assert allclose(grad, zero_grad, loose=loose), f'diff: {grad - zero_grad}' - - -def check_params_padding(model, zero_model, loose=False): - rank = dist.get_rank() - for p, zero_p in zip(model.parameters(), zero_model.parameters()): - zero_p = zero_p.clone().to(p.device) - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank] - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - assert p.dtype == zero_p.dtype - assert allclose(p, zero_p, loose=loose) - - -def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=False): - rank = dist.get_rank() - for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()): - if zero_p.colo_attr.param_is_sharded: - zero_p = zero_p.colo_attr.data_payload.to(p.device).float() - chunks = torch.flatten(p).chunk(dist.get_world_size()) - if rank >= len(chunks): - continue - p = chunks[rank].float() - if zero_p.size(0) > p.size(0): - zero_p = zero_p[:p.size(0)] - else: - zero_p = 
zero_p.colo_attr.data_payload.to(p.device) - - assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype) - assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}' diff --git a/tests/test_zero/test_legacy/test_found_inf.py b/tests/test_zero/test_legacy/test_found_inf.py deleted file mode 100644 index e90158e0a..000000000 --- a/tests/test_zero/test_legacy/test_found_inf.py +++ /dev/null @@ -1,67 +0,0 @@ -import pytest -import torch -from common import CONFIG -from test_sharded_optim_v2 import _run_step - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("cpu_offload", [True, False]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers'] - shard_strategy = shard_strategy_class() - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'cuda', - reuse_fp16_shard=True, - ) - - sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - for i, (data, label) in enumerate(train_dataloader): - if i > 1: - break - assert zero_model.overflow_counter == 0 - data, label = data.cuda(), label.cuda() - _run_step(zero_model, sharded_optim, data, label, criterion, False) - for param in zero_model.parameters(): - assert not has_inf_or_nan(param.colo_attr.data_payload) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_found_inf() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_found_inf(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_found_inf(world_size=2) diff --git a/tests/test_zero/test_legacy/test_gemini_manager.py b/tests/test_zero/test_legacy/test_gemini_manager.py deleted file mode 100644 index 0e956f7cc..000000000 --- a/tests/test_zero/test_legacy/test_gemini_manager.py +++ /dev/null @@ -1,75 +0,0 @@ -import pytest -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState - - -@pytest.mark.dist -@clear_cache_before_run() -def test_gemini_manager(): - # reset the 
manager, in case that there exists memory information left - manager = StatefulTensor.GST_MGR - manager.reset() - - # occupation 8 - st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda')) - # occupation 60 - st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu')) - - # occupation 28 - t1 = torch.empty(7, device='cuda') - # occupation 12 - t2 = torch.empty(3, device='cpu') - st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD) - st4 = StatefulTensor(None, TensorState.FREE) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 60 - assert manager.total_mem['cuda'] == 36 - assert manager.state_mem['cpu'][TensorState.HOLD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28 - - st4.payload_reset(t2) - st3.payload_reset(t2) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 84 - assert manager.total_mem['cuda'] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD] == 72 - assert manager.state_mem['cuda'][TensorState.HOLD] == 8 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0 - - st1.move_to(torch.device('cpu')) - st2.move_to(torch.device('cpu')) - st3.move_to(torch.device('cuda', 0)) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 80 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - - st1.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.COMPUTE) - st2.trans_state(TensorState.HOLD_AFTER_BWD) - - assert manager.total_number == 4 - assert manager.total_mem['cpu'] == 80 - assert manager.total_mem['cuda'] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD] == 12 - assert manager.state_mem['cuda'][TensorState.HOLD] == 0 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12 - assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60 - assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0 - assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8 - assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0 - - -if __name__ == '__main__': - test_gemini_manager() diff --git a/tests/test_zero/test_legacy/test_init_context.py b/tests/test_zero/test_legacy/test_init_context.py deleted file mode 100644 index 844938271..000000000 --- a/tests/test_zero/test_legacy/test_init_context.py +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.logging import get_dist_logger -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.utils.memory import colo_device_memory_used -from colossalai.zero.gemini.memory_tracer.utils import colo_model_mem_usage -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("init_device_type", ['cpu', 'cuda']) 
-@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_model_test(init_device_type, shard_strategy_class): - logger = get_dist_logger("test_zero_init") - - for name, get_components_func in non_distributed_component_funcs._registry.items(): - # because the ZeroInitContext automatically turns parameters to fp16 - # and the beit model use tensor.erfinv_() function to initialize weights - # tensor.erfinv_() doesn't support Half in CPU, we omit the beit model - if name == 'beit': - continue - model_builder, _, _, _, _ = get_components_func() - if init_device_type == 'cuda': - init_device = get_current_device() - elif init_device_type == 'cpu': - init_device = torch.device("cpu") - else: - continue - - model_numel_tensor = torch.zeros(1, dtype=torch.int) - with ZeroInitContext(target_device=init_device, - shard_strategy=shard_strategy_class(), - shard_param=True, - model_numel_tensor=model_numel_tensor): - model = model_builder(checkpoint=True) - - for param in model.parameters(): - assert hasattr(param, 'colo_attr') - assert param.colo_attr.sharded_data_tensor.dtype == torch.half - assert param.colo_attr.sharded_data_tensor.is_sharded - assert param.colo_attr.data_payload.device.type == init_device.type, \ - f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}' - - cuda_mem_use, _ = colo_model_mem_usage(model) - model_data_cuda_mem_MB = cuda_mem_use / 1e6 - logger.info(f"Existing ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0]) - sys_cuda_mem_MB = colo_device_memory_used(get_current_device()) / 1e6 - logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0]) - logger.info(f"Model Number Parameter {model_numel_tensor.numpy()[0]/1e6} M", ranks=[0]) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 4]) -@rerun_if_address_is_in_use() -def test_zero_init_context(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_init_context(1) diff --git a/tests/test_zero/test_legacy/test_param_op.py b/tests/test_zero/test_legacy/test_param_op.py deleted file mode 100644 index b91371b98..000000000 --- a/tests/test_zero/test_legacy/test_param_op.py +++ /dev/null @@ -1,82 +0,0 @@ -import copy - -import torch - -from colossalai.testing import clear_cache_before_run -from colossalai.zero.legacy.gemini.paramhooks import BaseParamHookMgr -from tests.components_to_test.registry import non_distributed_component_funcs - - -def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: - if loose: - return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) - return torch.allclose(tensor_a, tensor_b) - - -def run_model(model, inputs, label, criterion, use_param_hook=False): - if use_param_hook: - - class HooKWrapper: - - def __init__(self) -> None: - self.hook_triggered_times = 0 - - def wrapper_func(self): - - def hook(param, grad) -> torch.Tensor or None: - self.hook_triggered_times += 1 - return grad - - return hook - - hookwrapper = HooKWrapper() - param_list = [p for p in model.parameters()] - hook_mgr = BaseParamHookMgr(param_list) - hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) - - model.zero_grad(set_to_none=True) - - with torch.cuda.amp.autocast(): - if criterion: - y = model(inputs) - loss = criterion(y, label) - else: - loss = model(inputs, label) - 
loss = loss.float() - loss.backward() - - if use_param_hook: - hook_mgr.remove_hooks() - return hookwrapper.hook_triggered_times - - -@clear_cache_before_run() -def test_base_param_hook(): - test_models = ['repeated_computed_layers', 'resnet18', 'hanging_param_model', 'inline_op_model'] - # test_models = ['bert'] - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - torch.manual_seed(0) - model = model_builder(checkpoint=True).cuda() - model.train() - - for i, (inputs, label) in enumerate(train_dataloader): - if i > 0: - break - model_copy = copy.deepcopy(model) - - run_model(model, inputs.cuda(), label.cuda(), criterion, False) - ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True) - - # Make sure param hook has only be fired once in case of parameter sharing - assert ret2 == len(list(model.parameters())) - - for p, p_copy in zip(model.parameters(), model_copy.parameters()): - assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" - - -if __name__ == '__main__': - test_base_param_hook() diff --git a/tests/test_zero/test_legacy/test_shard_model_v2.py b/tests/test_zero/test_legacy/test_shard_model_v2.py deleted file mode 100644 index 93d624aa2..000000000 --- a/tests/test_zero/test_legacy/test_shard_model_v2.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -from common import CONFIG, check_grads_padding, run_fwd_bwd -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("enable_autocast", [True]) -@parameterize("shard_strategy_class", [BucketTensorShardStrategy]) -def run_model_test(enable_autocast, shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, _, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - model = DDP(model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - - data, label = cast_tensor_to_fp16(data).cuda(), label.cuda() - run_fwd_bwd(model, data, label, criterion, enable_autocast) - run_fwd_bwd(zero_model, data, label, criterion, enable_autocast) - - check_grads_padding(model, zero_model, loose=True) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', 
port=port, backend='nccl') - run_model_test() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_model_v2(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_shard_model_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_shard_param.py b/tests/test_zero/test_legacy/test_shard_param.py deleted file mode 100644 index 4ba43edce..000000000 --- a/tests/test_zero/test_legacy/test_shard_param.py +++ /dev/null @@ -1,91 +0,0 @@ -from copy import deepcopy - -import pytest -import torch -from common import CONFIG, allclose - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_param import ShardedTensor -from colossalai.zero.legacy.sharded_param.sharded_param import ShardedParamV2 - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_shard_tensor_with_strategy(shard_strategy_class, world_size): - t = ShardedTensor(tensor=torch.randn(world_size * 2, 3)) - assert list(t.origin_shape) == [world_size * 2, 3] - assert list(t.shape) == [world_size * 2, 3] - - shard_strategy = shard_strategy_class() - - # test shard strategy - shard_strategy.shard([t]) - assert list(t.shape) == [6], f"{list(t.shape)} vs 6" - shard_strategy.gather([t]) - assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}" - - -def _run_shard_tensor(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_shard_tensor_with_strategy(world_size=world_size) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_shard_tensor(world_size): - spawn(_run_shard_tensor, world_size) - - -def _run_shard_param_v2(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - param = torch.nn.Parameter(torch.randn(2, 3)) - param_ref = deepcopy(param) - sparam = ShardedParamV2(param=param) - - allclose(sparam.data_payload, param_ref.data) - - # Test get memory usage - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" - - sparam.set_data_none() - assert (param.data.numel() == 0) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - # 4 is size of dummy tensor of param.data - assert cpu_mem_use == 2 * 3 * 4 * 2 - - sparam.saved_grad = StatefulTensor(torch.randn(2, 3)) - sparam.set_data_none() - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - # append a grad to torch param - param.data = sparam.data_payload - param.grad = torch.randn(2, 3) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}" - assert cuda_mem_use == 0 - - # reuse torch grad for sparam - sparam.saved_grad = StatefulTensor(param.grad) - cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() - assert cpu_mem_use == 2 * 3 * 4 * 2 - assert cuda_mem_use == 0 - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) 
-@rerun_if_address_is_in_use() -def test_shard_param_v2(world_size): - spawn(_run_shard_param_v2, world_size) - - -if __name__ == '__main__': - # test_shard_tensor(2) - test_shard_param_v2(2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py b/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py deleted file mode 100644 index 1ca144662..000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_state_dict.py +++ /dev/null @@ -1,89 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.nn.optimizer import HybridAdam -from colossalai.tensor import ProcessGroup -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from tests.components_to_test.registry import non_distributed_component_funcs -from tests.test_tensor.common_utils import set_seed - - -def init_zero(model_builder, placement_policy): - device = get_current_device() if placement_policy == 'cuda' else torch.device('cpu') - shard_strategy = TensorShardStrategy() - with ZeroInitContext(target_device=device, shard_strategy=shard_strategy, shard_param=True): - model = model_builder() - model = ShardedModelV2( - model, - shard_strategy, - tensor_placement_policy=placement_policy, - reuse_fp16_shard=True, - ) - optim = HybridAdam(model.parameters(), lr=1e-3) - optim = ShardedOptimizerV2(model, optim, initial_scale=32) - return model, optim - - -def run_step(model, optim, criterion, data, label): - optim.zero_grad() - logits = model(data) - loss = criterion(logits, label) - optim.backward(loss) - optim.step() - - -def check_state_dict_eq(state_dict, other): - for p, state in state_dict['state'].items(): - other_state = other['state'][p] - for k, v in state.items(): - if isinstance(v, torch.Tensor): - assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}' - else: - assert v == other_state[k] - - -@parameterize('placement_policy', ['cuda', 'cpu']) -def run_nested_model(placement_policy): - get_components_func = non_distributed_component_funcs.get_callable('simple_net') - model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() - - set_seed(42) - model, optim = init_zero(model_builder, placement_policy) - set_seed(42) - model_copy, optim_copy = init_zero(model_builder, placement_policy) - - model.train() - model_copy.train() - pg = ProcessGroup() - set_seed(pg.dp_local_rank()) - data_iter = iter(train_dataloader) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model, optim, criterion, data, label) - optim_copy.load_state_dict(optim.state_dict()) - check_state_dict_eq(optim.state_dict(), optim_copy.state_dict()) - - data, label = map(lambda x: x.cuda(), next(data_iter)) - run_step(model_copy, optim_copy, criterion, data, label) - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_nested_model() - - -@pytest.mark.dist -@pytest.mark.parametrize('world_size', [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_state_dist(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_state_dist(2) diff --git 
a/tests/test_zero/test_legacy/test_sharded_optim_v2.py b/tests/test_zero/test_legacy/test_sharded_optim_v2.py deleted file mode 100644 index c6f77995e..000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_v2.py +++ /dev/null @@ -1,110 +0,0 @@ -import pytest -import torch -import torch.distributed as dist -from common import CONFIG, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.amp import convert_to_apex_amp -from colossalai.nn.optimizer import CPUAdam -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.legacy.sharded_optim import ShardedOptimizerV2 -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def _run_step(model, optimizer, data, label, criterion, enable_autocast=False): - model.train() - optimizer.zero_grad() - with torch.cuda.amp.autocast(enabled=enable_autocast): - if criterion: - y = model(data) - loss = criterion(y, label) - else: - loss = model(data, label) - - loss = loss.float() - if isinstance(model, ShardedModelV2): - optimizer.backward(loss) - else: - loss.backward() - optimizer.step() - - -@parameterize("cpu_offload", [True, False]) -@parameterize("use_cpuadam", [True, False]) -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -@parameterize("gpu_margin_mem_ratio", [0.0, 0.7]) -def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, gpu_margin_mem_ratio): - test_models = ['repeated_computed_layers', 'resnet18', 'bert', 'hanging_param_model'] - shard_strategy = shard_strategy_class() - - if use_cpuadam and cpu_offload is False: - return - if gpu_margin_mem_ratio > 0.0 and not (cpu_offload and use_cpuadam): - return - - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device(f'cpu:0') if cpu_offload else get_current_device(), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2( - zero_model, - shard_strategy, - tensor_placement_policy='cpu' if cpu_offload else 'auto', - reuse_fp16_shard=use_cpuadam, - ) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda().float() - - if use_cpuadam: - optimizer_class = CPUAdam - optim = optimizer_class(model.parameters(), lr=1e-3) - sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3) - sharded_optim = ShardedOptimizerV2(zero_model, - sharded_optim, - initial_scale=2**5, - gpu_margin_mem_ratio=gpu_margin_mem_ratio) - - amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False) - apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config) - if dist.get_world_size() > 1: - apex_model = DDP(apex_model, device_ids=[torch.cuda.current_device()]) - - for i, (data, label) in enumerate(train_dataloader): - if i > 5: - break - data, label = data.cuda(), 
label.cuda() - _run_step(apex_model, apex_optimizer, data, label, criterion, False) - _run_step(zero_model, sharded_optim, data, label, criterion, False) - check_sharded_model_params(model, zero_model, loose=True, reuse_fp16_shard=use_cpuadam) - for param in model.parameters(): - assert not has_inf_or_nan(param) - - -def _run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - _run_test_sharded_optim_v2() - - -# use_cpuadam = True can be used with cpu_offload = False -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_sharded_optim_v2(world_size): - spawn(_run_dist, world_size) - - -if __name__ == '__main__': - test_sharded_optim_v2(world_size=2) diff --git a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py b/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py deleted file mode 100644 index 0223f18c2..000000000 --- a/tests/test_zero/test_legacy/test_sharded_optim_with_sync_bn.py +++ /dev/null @@ -1,87 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from torchvision.models import resnet50 - -import colossalai -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import TensorShardStrategy - - -def run_dist(rank, world_size, port): - # this test only runs on resnet18 - # as this model has sync batch normalization - # need to configure cudnn deterministic so that - # randomness of convolution layers will be disabled - zero_config = dict(model_config=dict(shard_strategy=TensorShardStrategy())) - colossalai.launch(config=dict(zero=zero_config, cudnn_deterministic=True, cudnn_benchmark=False), - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - - with ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): - model = resnet50() - optimizer = torch.optim.Adam(model.parameters(), lr=0.001) - criterion = torch.nn.CrossEntropyLoss() - - engine, *args = colossalai.initialize(model, optimizer, criterion) - - # train for dummy iterations - engine.train() - for _ in range(2): - data = torch.rand(4, 3, 128, 128).cuda().half() - label = torch.randint(0, 10, size=(4,)).cuda() - engine.zero_grad() - out = engine(data) - loss = engine.criterion(out, label) - engine.backward(loss) - engine.step() - - # test - # need to make sure the batch norm stats are synchronized - # so that given the same input, the model will produce the same - # output on different ranks - engine.eval() - data = torch.rand(4, 3, 128, 128).cuda().half() - dist.broadcast(data, src=0, group=gpc.get_group(ParallelMode.DATA)) - - # predict - out = engine(data) - - # test if results are equal - tensor_list = [torch.empty_like(out) for _ in range(world_size - 1)] - tensor_list.insert(rank, out) - dist.all_gather(tensor_list=tensor_list, tensor=out, group=gpc.get_group(ParallelMode.DATA)) - - assert torch.all(tensor_list[0] == tensor_list[1]), \ - 'expected the output from different ranks to be the same, but got different values' - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_sharded_optim_with_sync_bn(): - """ - This test is to make sure 
that buffers are synchronized between ranks - when using ZeRO. An example of module buffer is the running stats of - BatchNormalization layer, i.e. mean and var. - - If the buffers are not synchronized, the model will produce different - output even though the input and parameters are the same. This is not - wanted if we are doing predictions. - - """ - spawn(run_dist, 2) - - -if __name__ == '__main__': - test_sharded_optim_with_sync_bn() diff --git a/tests/test_zero/test_legacy/test_state_dict.py b/tests/test_zero/test_legacy/test_state_dict.py deleted file mode 100644 index 5f76fff3e..000000000 --- a/tests/test_zero/test_legacy/test_state_dict.py +++ /dev/null @@ -1,55 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -from common import CONFIG - -import colossalai -from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.shard_utils import BucketTensorShardStrategy, TensorShardStrategy -from colossalai.zero.legacy.sharded_model import ShardedModelV2 -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from tests.components_to_test.registry import non_distributed_component_funcs - - -@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) -def run_zero_state_dict(shard_strategy_class): - test_models = ['repeated_computed_layers', 'resnet18'] - shard_strategy = shard_strategy_class() - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, test_dataloader, optimizer, criterion = get_components_func() - - with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()), - shard_strategy=shard_strategy, - shard_param=True): - zero_model = model_builder(checkpoint=True) - zero_model = ShardedModelV2(zero_model, shard_strategy) - - model = model_builder(checkpoint=True).half() - col_model_deepcopy(zero_model, model) - model = model.cuda() - - zero_state_dict = zero_model.state_dict() - for key, val in model.state_dict().items(): - assert torch.equal(val, zero_state_dict[key].to(val.device)) - - -def run_dist(rank, world_size, port): - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_zero_state_dict() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@rerun_if_address_is_in_use() -def test_zero_state_dict(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_state_dict(2) diff --git a/tests/test_zero/test_legacy/test_tensor_utils.py b/tests/test_zero/test_legacy/test_tensor_utils.py deleted file mode 100644 index 238bc3fe1..000000000 --- a/tests/test_zero/test_legacy/test_tensor_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import pytest -import torch - -import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.utils.cuda import get_current_device -from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor -from colossalai.zero.legacy.gemini.tensor_utils import ( - colo_model_data_move_to_cpu, - colo_model_data_tensor_move, - colo_model_data_tensor_move_inline, - colo_model_tensor_clone, - colo_tensor_mem_usage, -) - - -def _run_colo_tensor_mem_usage(): - for i in range(1): - if i == 1: - t1 = StatefulTensor(torch.randn(2, 2)) - t2 = StatefulTensor(torch.randn(4, 4)) - c1, g1 
= colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - else: - t1 = torch.randn(2, 2) - t2 = torch.randn(4, 4) - c1, g1 = colo_tensor_mem_usage(t1) - c2, g2 = colo_tensor_mem_usage(t2) - assert c1 * 4 == c2 - assert g1 * 4 == g2 - - -def _run_colo_model_data_tensor_move_inline(): - for t in [StatefulTensor(torch.randn(2, 3)), torch.randn(2, 3)]: - colo_model_data_tensor_move_inline(t, get_current_device()) - assert t.device == get_current_device() - - -def _run_colo_model_data_tensor_move(): - for t in [(StatefulTensor(torch.ones(2, 3)), StatefulTensor(torch.zeros(2, 3).to(get_current_device()))), - (torch.ones(2, 3), torch.zeros(2, 3).to(get_current_device()))]: - cpu_t, cuda_t = t - colo_model_data_tensor_move(cpu_t, cuda_t) - assert cuda_t.device == get_current_device() - - -def _run_colo_model_data_move_to_cpu(): - for t in [StatefulTensor(torch.randn(2, 2)), torch.randn(4, 4)]: - colo_model_data_move_to_cpu(t) - assert t.device == torch.device("cpu") - - -def _run_colo_model_tensor_clone(): - for t in [ - StatefulTensor(torch.randn(2, 2).cuda(torch.cuda.current_device())), - torch.randn(4, 4).cuda(torch.cuda.current_device()) - ]: - if issubclass(type(t), StatefulTensor): - assert t.payload.device == get_current_device() - else: - assert t.device == get_current_device() - p = colo_model_tensor_clone(t, get_current_device()) - assert p.device == get_current_device() - for i in range(2): - for j in range(2): - if issubclass(type(t), StatefulTensor): - assert t.payload.device == p.device - assert t.payload[i][j] == p[i][j] - else: - assert t.device == p.device - assert t[i][j] == p[i][j] - - -def run_dist(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - _run_colo_tensor_mem_usage() - _run_colo_model_data_tensor_move_inline() - _run_colo_model_data_tensor_move() - _run_colo_model_data_move_to_cpu() - _run_colo_model_tensor_clone() - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_zero_tensor_utils(world_size): - spawn(run_dist, world_size) - - -if __name__ == '__main__': - test_zero_tensor_utils(world_size=2) diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py deleted file mode 100644 index 826a543db..000000000 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ /dev/null @@ -1,113 +0,0 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -import pytest -import torch -import torch.distributed as dist -from common import MP_PARALLEL_CONFIG, ZERO_PARALLEL_CONFIG, check_params, check_sharded_model_params -from torch.nn.parallel import DistributedDataParallel as DDP - -import colossalai -from colossalai.core import global_context as gpc -from colossalai.testing import rerun_if_address_is_in_use, spawn -from colossalai.zero.legacy.init_ctx import ZeroInitContext -from colossalai.zero.legacy.sharded_model.utils import col_model_deepcopy -from colossalai.zero.low_level._utils import has_inf_or_nan -from tests.components_to_test.registry import non_distributed_component_funcs - - -def run_dist(rank, world_size, port, parallel_config, bf16): - is_mp_config = parallel_config == MP_PARALLEL_CONFIG - is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG - if bf16: - parallel_config['zero']['model_config']['bf16'] = True - colossalai.launch(config=parallel_config, - rank=rank, - world_size=world_size, - host='localhost', - 
port=port, - backend='nccl') - - test_models = ['repeated_computed_layers', 'resnet18', 'bert'] - for model_name in test_models: - get_components_func = non_distributed_component_funcs.get_callable(model_name) - model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() - with ZeroInitContext(target_device=torch.cuda.current_device(), - shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True, - bf16=bf16): - colo_model = model_builder(checkpoint=True) - - colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) - engine, train_dataloader, _, _ = colossalai.initialize(colo_model, - optimizer=colo_optimizer, - criterion=criterion, - train_dataloader=train_dataloader) - dtype = torch.bfloat16 if bf16 else torch.float16 - torch_model = model_builder(checkpoint=True).to(dtype) - col_model_deepcopy(engine.model, torch_model) - torch_model = torch_model.cuda().float() - - engine.train() - torch_optimizer = optimizer_class(torch_model.parameters(), lr=1e-3) - - if dist.get_world_size() > 1: - torch_model = DDP(torch_model, device_ids=[torch.cuda.current_device()]) - - i = 0 - for data, label in train_dataloader: - if i > 4: - break - - data, label = data.cuda(), label.cuda() - - engine.zero_grad() - torch_optimizer.zero_grad() - - if criterion: - output = engine(data) - loss = engine.criterion(output, label) - - torch_output = torch_model(data) - torch_loss = engine.criterion(torch_output, label) - else: - loss = engine(data, label) - torch_loss = torch_model(data, label) - - engine.backward(loss) - engine.step() - - torch_loss.backward() - - for param in torch_model.parameters(): - if param.grad is not None: - assert not has_inf_or_nan(param.grad) - - torch_optimizer.step() - i += 1 - - if is_mp_config: - check_params(torch_model, colo_model, loose=True) - elif is_zero_config: - check_sharded_model_params(torch_model, colo_model, loose=True) - - -# FIXME: enable this test in next PR -@pytest.mark.skip -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [2, 4]) -@rerun_if_address_is_in_use() -def test_mp_engine(world_size): - spawn(run_dist, world_size, parallel_config=MP_PARALLEL_CONFIG) - - -@pytest.mark.dist -@pytest.mark.parametrize("world_size", [1, 2]) -@pytest.mark.parametrize("bf16", [True, False]) -@rerun_if_address_is_in_use() -def test_zero_engine(world_size, bf16): - spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16) - - -if __name__ == '__main__': - test_zero_engine(world_size=4)
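
For reference: the docstring removed from test_sharded_optim_with_sync_bn above explains that module buffers, such as a BatchNorm layer's running mean and variance, must stay identical across ranks under ZeRO, otherwise the same input produces different outputs on different ranks at inference time. Below is a minimal sketch of how that property can be checked with plain torch.distributed collectives. The helper name assert_buffers_synced is hypothetical and not part of ColossalAI; the sketch assumes the default process group is already initialized and that buffers reside on the current CUDA device (as the nccl backend requires).

import torch
import torch.distributed as dist


def assert_buffers_synced(module: torch.nn.Module) -> None:
    """Hypothetical helper: assert every registered buffer is identical on all ranks."""
    world_size = dist.get_world_size()
    for name, buf in module.named_buffers():
        local = buf.detach().clone()
        # Gather this buffer from every rank into a list of same-shaped tensors.
        gathered = [torch.empty_like(local) for _ in range(world_size)]
        dist.all_gather(gathered, local)
        for rank, remote in enumerate(gathered):
            # Running stats are floating point; num_batches_tracked is an integer tensor.
            same = torch.allclose(local, remote) if local.is_floating_point() else torch.equal(local, remote)
            assert same, f'buffer "{name}" differs between this rank and rank {rank}'

Under these assumptions, calling assert_buffers_synced on the trained module (e.g. a model containing torch.nn.BatchNorm2d, after a few optimizer steps) would fail loudly if buffer synchronization were broken.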