From 99870726b140d775713c5855a8819a0fe05f0ff9 Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 8 Nov 2022 15:15:13 +0800 Subject: [PATCH] [CheckpointIO] a uniform checkpoint I/O module (#1689) --- colossalai/utils/checkpoint_io/__init__.py | 2 + colossalai/utils/checkpoint_io/backend.py | 74 ++++++ colossalai/utils/checkpoint_io/constant.py | 9 + colossalai/utils/checkpoint_io/convertor.py | 227 ++++++++++++++++++ colossalai/utils/checkpoint_io/distributed.py | 127 ++++++++++ colossalai/utils/checkpoint_io/io.py | 170 +++++++++++++ colossalai/utils/checkpoint_io/meta.py | 81 +++++++ colossalai/utils/checkpoint_io/reader.py | 131 ++++++++++ colossalai/utils/checkpoint_io/utils.py | 223 +++++++++++++++++ colossalai/utils/checkpoint_io/writer.py | 98 ++++++++ .../test_build_checkpoints.py | 120 +++++++++ .../test_checkpoint_io/test_load.py | 188 +++++++++++++++ .../test_checkpoint_io/test_merge.py | 127 ++++++++++ .../test_checkpoint_io/test_merge_param.py | 101 ++++++++ .../test_checkpoint_io/test_redist.py | 149 ++++++++++++ .../test_checkpoint_io/test_save.py | 147 ++++++++++++ .../test_checkpoint_io/test_unmerge_param.py | 137 +++++++++++ 17 files changed, 2111 insertions(+) create mode 100644 colossalai/utils/checkpoint_io/__init__.py create mode 100644 colossalai/utils/checkpoint_io/backend.py create mode 100644 colossalai/utils/checkpoint_io/constant.py create mode 100644 colossalai/utils/checkpoint_io/convertor.py create mode 100644 colossalai/utils/checkpoint_io/distributed.py create mode 100644 colossalai/utils/checkpoint_io/io.py create mode 100644 colossalai/utils/checkpoint_io/meta.py create mode 100644 colossalai/utils/checkpoint_io/reader.py create mode 100644 colossalai/utils/checkpoint_io/utils.py create mode 100644 colossalai/utils/checkpoint_io/writer.py create mode 100644 tests/test_utils/test_checkpoint_io/test_build_checkpoints.py create mode 100644 tests/test_utils/test_checkpoint_io/test_load.py create mode 100644 tests/test_utils/test_checkpoint_io/test_merge.py create mode 100644 tests/test_utils/test_checkpoint_io/test_merge_param.py create mode 100644 tests/test_utils/test_checkpoint_io/test_redist.py create mode 100644 tests/test_utils/test_checkpoint_io/test_save.py create mode 100644 tests/test_utils/test_checkpoint_io/test_unmerge_param.py diff --git a/colossalai/utils/checkpoint_io/__init__.py b/colossalai/utils/checkpoint_io/__init__.py new file mode 100644 index 000000000..fe0308668 --- /dev/null +++ b/colossalai/utils/checkpoint_io/__init__.py @@ -0,0 +1,2 @@ +from .io import load, merge, redist, save +from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta) diff --git a/colossalai/utils/checkpoint_io/backend.py b/colossalai/utils/checkpoint_io/backend.py new file mode 100644 index 000000000..140192c05 --- /dev/null +++ b/colossalai/utils/checkpoint_io/backend.py @@ -0,0 +1,74 @@ +import shutil +import tempfile +from abc import ABC, abstractmethod +from typing import Dict, List, Type + +from .reader import CheckpointReader, DiskCheckpointReader +from .writer import CheckpointWriter, DiskCheckpointWriter + +_backends: Dict[str, Type['CheckpointIOBackend']] = {} + + +def register(name: str): + assert name not in _backends, f'"{name}" is registered' + + def wrapper(cls): + _backends[name] = cls + return cls + + return wrapper + + +def get_backend(name: str) -> 'CheckpointIOBackend': + assert name in _backends, f'Unsupported backend "{name}"' + return _backends[name]() + + +class CheckpointIOBackend(ABC): + + def 
__init__(self) -> None: + super().__init__() + self.temps: List[str] = [] + + @abstractmethod + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + pass + + @abstractmethod + def get_reader(self, base_name: str) -> CheckpointReader: + pass + + @abstractmethod + def get_temp(self, base_name: str) -> str: + pass + + @abstractmethod + def clean_temp(self) -> None: + pass + + +@register('disk') +class CheckpointDiskIO(CheckpointIOBackend): + + def get_writer(self, + base_name: str, + overwrite: bool = False, + rank: int = 0, + world_size: int = 1) -> CheckpointWriter: + return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size) + + def get_reader(self, base_name: str) -> CheckpointReader: + return DiskCheckpointReader(base_name) + + def get_temp(self, base_name: str) -> str: + temp_dir_name = tempfile.mkdtemp(dir=base_name) + self.temps.append(temp_dir_name) + return temp_dir_name + + def clean_temp(self) -> None: + for temp_dir_name in self.temps: + shutil.rmtree(temp_dir_name) diff --git a/colossalai/utils/checkpoint_io/constant.py b/colossalai/utils/checkpoint_io/constant.py new file mode 100644 index 000000000..219948474 --- /dev/null +++ b/colossalai/utils/checkpoint_io/constant.py @@ -0,0 +1,9 @@ +import re + +GLOBAL_META_FILE_NAME = 'global_meta.bin' +MODEL_CKPT_FILE_NAME = 'model.bin' +OPTIM_CKPT_FILE_NAME = 'optim.bin' +META_CKPT_FILE_NAME = 'meta.bin' +OTHER_CKPT_FILE_NAME = 'other.bin' + +CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other') diff --git a/colossalai/utils/checkpoint_io/convertor.py b/colossalai/utils/checkpoint_io/convertor.py new file mode 100644 index 000000000..529ceb868 --- /dev/null +++ b/colossalai/utils/checkpoint_io/convertor.py @@ -0,0 +1,227 @@ +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional + +from torch import Tensor + +from .distributed import merge_param, unmerge_param +from .meta import ParamDistMeta, RedistMeta +from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none) + + +class CheckpointConvertor(ABC): + + @abstractmethod + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + pass + + @abstractmethod + def complete(self) -> None: + pass + + +class ModelCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int]) -> None: + super().__init__() + self.param_count = param_count + self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict) + + @abstractmethod + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + for k, tensor in state_dict.items(): + self.buffer[k][rank] = tensor + converted_keys = set() + for k, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[k]: + tensors = [] + dist_metas = [] + for rank, tensor in rank_dict.items(): + tensors.append(tensor) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][k]) + self.convert_tensors(k, tensors, dist_metas) + converted_keys.add(k) + for k in converted_keys: + del self.buffer[k] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class ModelCheckpointMerger(ModelCheckpointConvertor): + + def 
__init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None: + super().__init__(param_count) + self.sharder = ModelCheckpointSharder(max_shard_size) + self.save_fn = save_fn + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + shard = self.sharder.append(key, tensor) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class ModelCheckpointRedistor(ModelCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + redist_meta: RedistMeta) -> None: + super().__init__(param_count) + self.save_fns = save_fns + self.redist_meta = redist_meta + nprocs = len(save_fns) + self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None: + if len(dist_metas) == 0: + # already global + tensor = tensors[0] + else: + assert len(dist_metas) == len(tensors) + tensor = merge_param(tensors, dist_metas) + for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[key][tp_rank][dp_rank]: + shard = self.sharders[rank].append(key, t) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) + + +class OptimizerCheckpointConvertor(CheckpointConvertor): + + def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]], + paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__() + self.param_count = param_count + self.param_to_os = param_to_os + self.paired_os = paired_os + self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict) + self.os_to_param = {v: k for k, v in param_to_os.items()} + + @abstractmethod + def setup(self, param_groups: dict) -> None: + pass + + @abstractmethod + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + pass + + def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for rank, state_dict in shard_dict.items(): + self.setup(state_dict['param_groups']) + for idx, state in state_dict['state'].items(): + self.buffer[idx][rank] = state + converted_indices = set() + for idx, rank_dict in self.buffer.items(): + if len(rank_dict) == self.param_count[self.os_to_param[idx]]: + states = [] + dist_metas = [] + for rank, state in rank_dict.items(): + states.append(state) + if dist_meta_list[rank] is not None: + dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]]) + self.convert_states(idx, states, dist_metas) + converted_indices.add(idx) + for idx in converted_indices: + del self.buffer[idx] + + def complete(self) -> None: + assert len(self.buffer) == 0 + + +class OptimizerCheckpointMerger(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fn: 
Callable[[dict], Any], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fn = save_fn + self.sharder = None + + def setup(self, param_groups: dict) -> None: + if self.sharder is None: + self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + assert len(dist_metas) == len(states) + new_state = {} + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas) + else: + new_state[state_key] = state_tensor + shard = self.sharder.append(idx, new_state) + run_if_not_none(self.save_fn, shard) + + def complete(self) -> None: + super().complete() + run_if_not_none(self.save_fn, self.sharder.complete()) + + +class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor): + + def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int], + param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]], + redist_meta: RedistMeta) -> None: + super().__init__(param_count, param_to_os, paired_os) + self.max_shard_size = max_shard_size + self.save_fns = save_fns + self.redist_meta = redist_meta + self.sharders: List[OptimizerCheckpointSharder] = [] + self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + for k, rank_meta in redist_meta.rank_meta.items(): + for rank, rank_info in rank_meta.items(): + self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank) + + def setup(self, param_groups: dict) -> None: + if len(self.sharders) == 0: + nprocs = len(self.save_fns) + for _ in range(nprocs): + self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups)) + + def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None: + need_merge: bool = True + if len(dist_metas) == 0: + need_merge = False + else: + assert len(dist_metas) == len(states) + new_states = [{} for _ in range(len(self.save_fns))] + for state_key, state_tensor in states[0].items(): + if self.paired_os[idx][state_key]: + if need_merge: + tensor = merge_param([state[state_key] for state in states], dist_metas) + else: + tensor = state_tensor + for tp_rank, tensor_list in enumerate( + unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])): + for dp_rank, t in enumerate(tensor_list): + for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]: + new_states[rank][state_key] = t + else: + for new_state in new_states: + new_state[state_key] = state_tensor + for rank, new_state in enumerate(new_states): + shard = self.sharders[rank].append(idx, new_state) + run_if_not_none(self.save_fns[rank], shard) + + def complete(self) -> None: + super().complete() + for rank, save_fn in enumerate(self.save_fns): + run_if_not_none(save_fn, self.sharders[rank].complete()) diff --git a/colossalai/utils/checkpoint_io/distributed.py b/colossalai/utils/checkpoint_io/distributed.py new file mode 100644 index 000000000..bf720437c --- /dev/null +++ b/colossalai/utils/checkpoint_io/distributed.py @@ -0,0 +1,127 @@ +import torch +from numpy import prod +from torch import Tensor +from typing import List, Optional, Tuple +from collections import defaultdict +from .meta import ParamDistMeta, 
ParamRedistMeta
+
+
+def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
+    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
+    for dist_meta in dist_metas[1:]:
+        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
+    if not dist_metas[0].used_zero:
+        # tensors are replicate
+        return tensors[0]
+    numel = dist_metas[0].zero_numel
+    orig_shape = dist_metas[0].zero_orig_shape
+    tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
+    assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.'
+    return torch.cat(tensors).reshape(orig_shape)
+
+
+def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
+    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
+    for dist_meta in dist_metas[1:]:
+        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
+    for t in tensors[1:]:
+        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
+    if not dist_metas[0].used_tp:
+        # tensors are replicate
+        return tensors[0]
+    total_parts = prod(dist_metas[0].tp_num_parts)
+    assert dist_metas[0].tp_world_size == total_parts, \
+        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_metas[0].tp_world_size}.'
+    shard_info = sorted(zip(dist_metas[0].tp_shard_dims, dist_metas[0].tp_num_parts), key=lambda t: t[0], reverse=True)
+    for dim, num_parts in shard_info:
+        buffer = []
+        for start in range(0, len(tensors), num_parts):
+            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
+        tensors = buffer
+    assert len(tensors) == 1
+    return tensors[0]
+
+
+def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
+    assert len(dist_metas) > 0
+    # check world size
+    for dist_meta in dist_metas[1:]:
+        assert dist_meta.dp_world_size == dist_metas[
+            0].dp_world_size, 'Expect all dist meta have the same dp_world_size'
+        assert dist_meta.tp_world_size == dist_metas[
+            0].tp_world_size, 'Expect all dist meta have the same tp_world_size'
+
+
+def deduplicate_params(tensors: List[Tensor],
+                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
+    unique_dist_meta = []
+    unique_idx = []
+    for i, dist_meta in enumerate(dist_metas):
+        if dist_meta not in unique_dist_meta:
+            unique_dist_meta.append(dist_meta)
+            unique_idx.append(i)
+    return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]
+
+
+def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
+    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
+    # validate parallel info
+    validate_parallel_info(dist_metas)
+    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
+    unflattened_tensors = []
+    # group zero params by tp rank
+    tensor_dict = defaultdict(list)
+    dist_meta_dict = defaultdict(list)
+    for t, dist_meta in zip(tensors, dist_metas):
+        tensor_dict[dist_meta.tp_rank].append(t)
+        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
+    assert len(tensor_dict
+               ) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
+    for tp_rank in tensor_dict.keys():
+        unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
+    return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])
+
+
+def split_tp_param(tensor: 
Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_tp: + assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.' + return [tensor] + total_parts = prod(redist_meta.tp_num_parts) + assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.' + shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0]) + tensors = [tensor] + for dim, num_parts in shard_info: + buffer = [] + for t in tensors: + assert t.size(dim) % num_parts == 0, \ + f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.' + chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)] + buffer.extend(chunks) + tensors = buffer + assert len(tensors) == redist_meta.tp_world_size + return tensors + + +def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]: + if not redist_meta.used_zero: + return [tensor] * redist_meta.dp_world_size + tensors: List[Optional[Tensor]] = [ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank) + ] + offsets = redist_meta.zero_offsets + [tensor.numel()] + for i, offset in enumerate(offsets[:-1]): + end = offsets[i + 1] + tensors.append(tensor.view(-1)[offset:end]) + if len(tensors) < redist_meta.dp_world_size: + tensors.extend([ + torch.empty(0, dtype=tensor.dtype, device=tensor.device) + for _ in range(redist_meta.dp_world_size - len(tensors)) + ]) + assert len(tensors) == redist_meta.dp_world_size + return tensors + + +def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]: + tensors = split_tp_param(tensor, redist_meta) + tensors = [flatten_zero_param(t, redist_meta) for t in tensors] + return tensors diff --git a/colossalai/utils/checkpoint_io/io.py b/colossalai/utils/checkpoint_io/io.py new file mode 100644 index 000000000..f00212cdf --- /dev/null +++ b/colossalai/utils/checkpoint_io/io.py @@ -0,0 +1,170 @@ +import warnings +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple + +import torch.distributed as dist +from torch.nn import Module +from torch.optim import Optimizer + +from .backend import get_backend +from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger, + OptimizerCheckpointRedistor) +from .meta import ParamDistMeta, RedistMeta +from .utils import build_checkpoints, optimizer_load_state_dict + + +def save(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk', + **kwargs: Any) -> None: + io_backend = get_backend(backend) + if dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + if world_size == 1: + # global doesn't need dist_meta + dist_meta = None + else: + assert dist_meta is not None + max_shard_size = int(max_shard_size_gb * 1024**3) + model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer, + param_to_os, dist_meta) + writer = io_backend.get_writer(path, overwrite, rank, world_size) + writer.save_others(kwargs) + for model_checkpoint in model_checkpoints: + writer.save_model(model_checkpoint) + for optimizer_checkpoint in optimizer_checkpoints: + 
writer.save_optimizer(optimizer_checkpoint) + writer.save_meta(meta_checkpoint) + + +def merge(path: str, + output_path: str, + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + reader = io_backend.get_reader(path) + if len(reader.meta_list) == 1: + # already global + warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.') + return False + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + writer = io_backend.get_writer(output_path, overwrite=overwrite) + writer.save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(), + dist_meta_list) + _convert_shards( + OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os), + reader.load_optimizers(), dist_meta_list) + meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def redist(path: str, + output_path: str, + redist_meta: RedistMeta, + dist_metas: List[Dict[str, ParamDistMeta]], + max_shard_size_gb: float = 0.0, + overwrite: bool = False, + backend: str = 'disk') -> bool: + io_backend = get_backend(backend) + if dist.is_initialized() and dist.get_rank() != 0: + return False + nprocs = len(dist_metas) + reader = io_backend.get_reader(path) + dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta() + do_redist: bool = False + if len(dist_meta_list) == nprocs: + for a, b in zip(dist_metas, dist_meta_list): + if a != b: + do_redist = True + break + else: + do_redist = True + if not do_redist: + warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.') + return False + + writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)] + writers[0].save_others(reader.load_others()) + max_shard_size = int(max_shard_size_gb * 1024**3) + _convert_shards( + ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta), + reader.load_models(), dist_meta_list) + _convert_shards( + OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count, + param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list) + for writer, dist_meta in zip(writers, dist_metas): + meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())} + if param_to_os is not None: + meta_checkpoint['param_to_os'] = param_to_os + meta_checkpoint['paired_os'] = paired_os + writer.save_meta(meta_checkpoint) + return True + + +def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None], + dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None: + for shard_dict in shard_generator: + convertor.append(shard_dict, dist_meta_list) + convertor.complete() + + +def load(path: str, + model: Module, + optimizer: Optional[Optimizer] = None, + redist_meta: Optional[RedistMeta] = None, + dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None, + max_shard_size_gb: float = 0.0, + backend: str = 'disk') -> dict: + is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1 + rank: int = 
dist.get_rank() if dist.is_initialized() else 0 + is_main_process: bool = rank == 0 + # validate args + if redist_meta is None or dist_metas is None: + assert is_global + io_backend = get_backend(backend) + read_path: str = path + if is_main_process: + # pre-process checkpoints + temp_path = io_backend.get_temp(path) + if is_global: + wrote = merge(path, temp_path, max_shard_size_gb, backend=backend) + else: + wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend) + if wrote: + read_path = temp_path + if not is_global: + bcast_list = [read_path] if is_main_process else [None] + dist.broadcast_object_list(bcast_list) + read_path = bcast_list[0] + reader = io_backend.get_reader(read_path) + # load model + for shard in reader.load_model(rank): + model.load_state_dict(shard, strict=False) + if optimizer is not None: + for shard in reader.load_optimizer(rank): + # optimizer.load_state_dict(shard) + optimizer_load_state_dict(optimizer, shard) + others_dict = reader.load_others() + if not is_global: + dist.barrier() + # clean up temp + if is_main_process: + io_backend.clean_temp() + return others_dict diff --git a/colossalai/utils/checkpoint_io/meta.py b/colossalai/utils/checkpoint_io/meta.py new file mode 100644 index 000000000..994f08b4b --- /dev/null +++ b/colossalai/utils/checkpoint_io/meta.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import List, Optional, Set, Dict + + +@dataclass +class ParamDistMeta: + # parallel info + dp_rank: int + dp_world_size: int + tp_rank: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_numel: Optional[int] = None + zero_orig_shape: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_numel is not None and self.zero_orig_shape is not None + + @property + def parallel_meta(self) -> tuple: + return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size + + @property + def tp_meta(self) -> tuple: + return self.tp_shard_dims, self.tp_num_parts + + @property + def zero_meta(self) -> tuple: + return self.zero_numel, self.zero_orig_shape + + @staticmethod + def from_dict(d: dict) -> 'ParamDistMeta': + return ParamDistMeta(**d) + + +@dataclass +class ParamRedistMeta: + # parallel info + dp_world_size: int + tp_world_size: int + # tp info + tp_shard_dims: Optional[List[int]] = None + tp_num_parts: Optional[List[int]] = None + # zero info + zero_start_dp_rank: Optional[int] = None + zero_offsets: Optional[List[int]] = None + + @property + def used_tp(self) -> bool: + return self.tp_shard_dims is not None and self.tp_num_parts is not None + + @property + def used_zero(self) -> bool: + return self.zero_start_dp_rank is not None and self.zero_offsets is not None + + +@dataclass +class RankRedistMeta: + dp_rank: int + tp_rank: int + pp_rank: int + + +@dataclass +class PipelineRedistMeta: + params: Set[str] + + +@dataclass +class RedistMeta: + rank_meta: Dict[str, Dict[int, RankRedistMeta]] + pipeline_meta: List[PipelineRedistMeta] + param_meta: Dict[str, ParamRedistMeta] diff --git a/colossalai/utils/checkpoint_io/reader.py b/colossalai/utils/checkpoint_io/reader.py new file mode 100644 index 000000000..3158c6481 --- /dev/null +++ b/colossalai/utils/checkpoint_io/reader.py @@ -0,0 +1,131 @@ +import os +from abc import ABC, abstractmethod +from collections 
import Counter +from typing import Dict, Generator, List, Optional, Tuple + +import torch + +from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME +from .meta import ParamDistMeta +from .utils import is_duplicated_list + + +class CheckpointReader(ABC): + + def __init__(self, base_name: str) -> None: + super().__init__() + self.base_name = base_name + self.meta_list = [] + + @abstractmethod + def read(self, name: str) -> dict: + pass + + @abstractmethod + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + pass + + @abstractmethod + def load_model(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_models(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + pass + + @abstractmethod + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + pass + + @abstractmethod + def load_others(self) -> dict: + pass + + +class DiskCheckpointReader(CheckpointReader): + + def __init__(self, base_name: str) -> None: + super().__init__(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + global_meta = self.read(GLOBAL_META_FILE_NAME) + for meta_file_name in global_meta['meta']: + meta = self.read(meta_file_name) + if meta.get('dist_meta', None) is None: + # only global checkpoint can have empty dist_meta + assert len(global_meta['meta']) == 1 + self.meta_list.append(meta) + + def read(self, name: str) -> dict: + return torch.load(os.path.join(self.base_name, name)) + + def load_meta( + self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]: + meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', + None), meta.get('paired_os', None)) + for meta in self.meta_list] + dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos) + # reduce param_count + param_count = Counter(p for params in params_list for p in params) + # validate param_to_os + assert is_duplicated_list(param_to_os_list) + assert is_duplicated_list(paired_os_list) + return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0] + + def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]: + meta = self.meta_list[rank] + checkpoint_names = meta.get(shard_type, []) + for name in checkpoint_names: + yield self.read(name) + + def load_model(self, rank: int) -> Generator[dict, None, None]: + return self._load_shard('model', rank) + + def load_models(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + model_checkpoint_names = meta.get('model', []) + if indices[i] < len(model_checkpoint_names): + shards[i] = self.read(model_checkpoint_names[indices[i]]) + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_optimizer(self, rank: int) -> Generator[dict, None, None]: + param_groups = None + for shard in self._load_shard('optimizer', rank): + if param_groups is None: + param_groups = shard['param_groups'] + else: + shard['param_groups'] = param_groups + yield shard + + def load_optimizers(self) -> Generator[Dict[int, dict], None, None]: + indices = [0] * len(self.meta_list) + param_groups = [] + while True: + shards = {} + for i, meta in enumerate(self.meta_list): + optimizer_checkpoint_names = meta.get('optimizer', []) + if 
indices[i] < len(optimizer_checkpoint_names): + shards[i] = self.read(optimizer_checkpoint_names[indices[i]]) + if indices[i] == 0: + param_groups.append(shards[i]['param_groups']) + else: + shards[i]['param_groups'] = param_groups[i] + indices[i] += 1 + if len(shards) > 0: + yield shards + else: + break + + def load_others(self) -> dict: + return self.read(OTHER_CKPT_FILE_NAME) diff --git a/colossalai/utils/checkpoint_io/utils.py b/colossalai/utils/checkpoint_io/utils.py new file mode 100644 index 000000000..135385f57 --- /dev/null +++ b/colossalai/utils/checkpoint_io/utils.py @@ -0,0 +1,223 @@ +import warnings +from copy import deepcopy +from itertools import chain +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import Tensor +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.optim import Optimizer + +from .meta import ParamDistMeta + + +def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any: + if arg is not None: + return fn(arg) + + +def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]: + # ensure all params in optimizer are in model state dict + params_set = set(id(p) for p in model.parameters()) + for group in optimizer.param_groups: + for p in group['params']: + assert id(p) in params_set + param_mappings = {} + start_index = 0 + + def get_group_mapping(group): + nonlocal start_index + param_mappings.update( + {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings}) + start_index += len(group['params']) + + for g in optimizer.param_groups: + get_group_mapping(g) + return {k: param_mappings[id(p)] for k, p in model.named_parameters()} + + +def compute_optimizer_state_size(state: Dict[str, Any]) -> int: + size = 0 + for v in state.values(): + if isinstance(v, Tensor): + size += v.numel() * v.element_size() + return size + + +class ModelCheckpointSharder: + + def __init__(self, max_shard_size: int) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, Tensor] = {} + self.buffer_size: int = 0 + + def append(self, key: str, tensor: Tensor) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {} + self.buffer_size = 0 + self.buffer[key] = tensor + self.buffer_size += tensor.numel() * tensor.element_size() + return retval + + def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]: + shards = [] + for key, tensor in state_dict.items(): + shard = self.append(key, tensor) + run_if_not_none(shards.append, shard) + return shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer) > 0 else None + + +class OptimizerCheckpointSharder: + + def __init__(self, max_shard_size: int, param_groups: dict) -> None: + self.max_shard_size = max_shard_size + self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups} + self.buffer_size: int = 0 + self.returned_first: bool = False + + def append(self, key: int, state: dict) -> Optional[dict]: + retval = None + if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size: + retval = self.buffer + self.buffer = {'state': {}} + self.buffer_size = 0 + self.buffer['state'][key] = state + self.buffer_size += compute_optimizer_state_size(state) + return retval + + def extend(self, state_dict: Dict[str, dict]) -> List[dict]: + shards = [] + for key, state in state_dict['state'].items(): + shard = self.append(key, state) + run_if_not_none(shards.append, shard) + return 
shards + + def complete(self) -> Optional[dict]: + return self.buffer if len(self.buffer['state']) > 0 else None + + +def shard_checkpoint(max_shard_size: int, + model_state_dict: Dict[str, Tensor], + optimizer_state_dict: Optional[dict] = None, + param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]: + has_optimizer: bool = False + if optimizer_state_dict is not None: + assert param_to_os is not None + os_to_param = {v: k for k, v in param_to_os.items()} + for os_key in optimizer_state_dict['state'].keys(): + assert os_key in os_to_param + assert os_to_param[os_key] in model_state_dict + has_optimizer = True + model_sharder = ModelCheckpointSharder(max_shard_size) + model_shards = model_sharder.extend(model_state_dict) + run_if_not_none(model_shards.append, model_sharder.complete()) + if not has_optimizer: + return model_shards, [] + optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups']) + optimizer_shards = optimizer_sharder.extend(optimizer_state_dict) + run_if_not_none(optimizer_shards.append, optimizer_sharder.complete()) + return model_shards, optimizer_shards + + +def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict: + os_to_param = {v: k for k, v in param_to_os.items()} + paired_os = {} + for idx, state in optimizer_state_dict['state'].items(): + paired_os[idx] = {} + p = model_state_dict[os_to_param[idx]] + for k, v in state.items(): + if isinstance(v, Tensor) and v.shape == p.shape: + paired_os[idx][k] = True + else: + paired_os[idx][k] = False + return paired_os + + +def build_checkpoints(max_size: int, + model: Module, + optimizer: Optional[Optimizer] = None, + param_to_os: Optional[Dict[str, int]] = None, + dist_meta: Optional[Dict[str, ParamDistMeta]] = None, + eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]: + save_global = dist_meta is None + model_state_dict = model.state_dict() + optimizer_state_dict = optimizer.state_dict() if optimizer else None + meta = {'dist_meta': dist_meta} + if optimizer: + param_to_os = param_to_os or get_param_to_os(model, optimizer) + paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os) + meta['param_to_os'] = param_to_os + meta['paired_os'] = paired_os + if not save_global and eliminate_replica: + # filter dp replicated params + model_state_dict = { + k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + if optimizer: + optimizer_state_dict['state'] = { + param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]] + for k in model_state_dict.keys() + if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0 + } + meta['params'] = list(model_state_dict.keys()) + if len(model_state_dict) == 0: + warnings.warn('model state dict is empty, checkpoint is not saved') + return [], [], meta + model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict, + param_to_os) + return model_checkpoints, optimizer_checkpoints, meta + + +def is_duplicated_list(list_: List[Any]) -> bool: + if len(list_) == 0: + return True + elem = list_[0] + for x in list_[1:]: + if x != elem: + return False + return True + + +def copy_optimizer_state(src_state: dict, dest_state: dict) -> None: + for k, v in src_state.items(): + if k in dest_state: + old_v = dest_state[k] + if isinstance(old_v, Tensor): + old_v.copy_(v) + else: + dest_state[k] = v + + +def optimizer_load_state_dict(optimizer: Optimizer, 
state_dict: dict, strict: bool = False) -> None: + assert optimizer.state_dict()['param_groups'] == state_dict['param_groups'] + state_dict = deepcopy(state_dict) + groups = optimizer.param_groups + saved_groups = state_dict['param_groups'] + idx_to_p: Dict[str, Parameter] = { + old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups + )), chain.from_iterable((g['params'] for g in groups))) + } + missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys())) + unexpected_keys = [] + error_msgs = [] + for idx, state in state_dict['state'].items(): + if idx in idx_to_p: + old_state = optimizer.state[idx_to_p[idx]] + copy_optimizer_state(state, old_state) + else: + unexpected_keys.append(idx) + if strict: + if len(unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys))) + if len(missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys))) + if len(error_msgs) > 0: + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__, + "\n\t".join(error_msgs))) diff --git a/colossalai/utils/checkpoint_io/writer.py b/colossalai/utils/checkpoint_io/writer.py new file mode 100644 index 000000000..4552accde --- /dev/null +++ b/colossalai/utils/checkpoint_io/writer.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Optional +from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME +import torch +import os + + +class CheckpointWriter(ABC): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__() + self.base_name = base_name + self.overwrite = overwrite + self.rank = rank + self.world_size = world_size + self.is_distributed = world_size > 1 + self.is_main_process = rank == 0 + + @abstractmethod + def write(self, name: str, state_dict: dict) -> None: + pass + + @abstractmethod + def save_model(self, model_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_meta(self, meta_checkpoint: dict) -> None: + pass + + @abstractmethod + def save_others(self, kwargs: dict) -> None: + pass + + +class DiskCheckpointWriter(CheckpointWriter): + + def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None: + super().__init__(base_name, overwrite, rank, world_size) + if not os.path.exists(base_name): + os.makedirs(base_name) + assert os.path.isdir(base_name), f'"{base_name}" is not a directory' + self.model_checkpoint_names = [] + self.optimizer_checkpoint_names = [] + self.is_meta_saved: bool = False + self._save_global_meta() + + def write(self, name: str, state_dict: dict) -> None: + path = os.path.join(self.base_name, name) + if os.path.exists(path) and not self.overwrite: + raise RuntimeError(f'Save error: Checkpoint "{path}" exists. 
(overwrite = False)') + torch.save(state_dict, path) + + def _save_global_meta(self) -> None: + if self.is_main_process: + global_meta = {'meta': []} + if self.is_distributed: + for i in range(self.world_size): + global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin')) + else: + global_meta['meta'].append(META_CKPT_FILE_NAME) + self.write(GLOBAL_META_FILE_NAME, global_meta) + + def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str: + checkpoint_name = base_name + if self.is_distributed: + checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin') + if shard_idx is not None: + checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin') + return checkpoint_name + + def save_model(self, model_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save model after saving meta' + name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names)) + self.write(name, model_checkpoint) + self.model_checkpoint_names.append(name) + + def save_optimizer(self, optimizer_checkpoint: dict) -> None: + assert not self.is_meta_saved, 'Cannot save optimizer after saving meta' + name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names)) + self.write(name, optimizer_checkpoint) + self.optimizer_checkpoint_names.append(name) + + def save_meta(self, meta_checkpoint: dict) -> None: + if len(self.model_checkpoint_names) > 0: + meta_checkpoint['model'] = self.model_checkpoint_names + if len(self.optimizer_checkpoint_names) > 0: + meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names + self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint) + self.is_meta_saved = True + + def save_others(self, kwargs: dict) -> None: + if self.is_main_process: + self.write(OTHER_CKPT_FILE_NAME, kwargs) diff --git a/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py new file mode 100644 index 000000000..6d89fb90c --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_build_checkpoints.py @@ -0,0 +1,120 @@ +import torch +import torch.nn as nn +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.utils import build_checkpoints +from torch.optim import Adam + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def test_global_model(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model) + assert len(model_checkpoints) == 1 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + global_state_dict = model_checkpoints[0] + assert set(orig_state_dict.keys()) == set(global_state_dict.keys()) + for k, v in orig_state_dict.items(): + assert torch.equal(v, global_state_dict[k]) + + +def test_global_model_shard(): + model = DummyModel() + model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model) + assert len(model_checkpoints) == 2 + assert len(optimizer_checkpoints) == 0 + assert meta['dist_meta'] is None + orig_state_dict = model.state_dict() + assert set(orig_state_dict.keys()) == set(model_checkpoints[0].keys()) | set(model_checkpoints[1].keys()) + assert len(set(model_checkpoints[0].keys()) & set(model_checkpoints[1].keys())) == 0 + for k, v in orig_state_dict.items(): + for state_dict in model_checkpoints: + if k in state_dict: + assert 
torch.equal(v, state_dict[k])
+
+
+def test_global_optimizer():
+    model = DummyModel()
+    for p in model.parameters():
+        p.grad = torch.rand_like(p)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    optimizer.step()
+    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
+    assert len(optimizer_checkpoints) == 1
+    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
+    for state in meta['paired_os'].values():
+        for k, is_paired in state.items():
+            if k == 'step':
+                assert not is_paired
+            else:
+                assert is_paired
+    orig_state_dict = optimizer.state_dict()
+    state_dict = optimizer_checkpoints[0]
+    for k, orig_state in orig_state_dict['state'].items():
+        state = state_dict['state'][k]
+        for v1, v2 in zip(orig_state.values(), state.values()):
+            if isinstance(v2, torch.Tensor):
+                assert torch.equal(v1, v2)
+            else:
+                assert v1 == v2
+    assert orig_state_dict['param_groups'] == state_dict['param_groups']
+
+
+def test_global_optimizer_shard():
+    model = DummyModel()
+    for p in model.parameters():
+        p.grad = torch.rand_like(p)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    optimizer.step()
+    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(80, model, optimizer)
+    assert len(optimizer_checkpoints) == 2
+    assert 'param_groups' in optimizer_checkpoints[0] and 'param_groups' not in optimizer_checkpoints[1]
+    orig_state_dict = optimizer.state_dict()
+    assert set(orig_state_dict['state'].keys()) == set(optimizer_checkpoints[0]['state'].keys()) | set(
+        optimizer_checkpoints[1]['state'].keys())
+    assert len(set(optimizer_checkpoints[0]['state'].keys()) & set(optimizer_checkpoints[1]['state'].keys())) == 0
+    for k, orig_state in orig_state_dict['state'].items():
+        state = optimizer_checkpoints[0]['state'][k] if k in optimizer_checkpoints[0][
+            'state'] else optimizer_checkpoints[1]['state'][k]
+        for v1, v2 in zip(orig_state.values(), state.values()):
+            if isinstance(v2, torch.Tensor):
+                assert torch.equal(v1, v2)
+            else:
+                assert v1 == v2
+
+    assert orig_state_dict['param_groups'] == optimizer_checkpoints[0]['param_groups']
+
+
+def test_dist_model_optimizer():
+    model = DummyModel()
+    for p in model.parameters():
+        p.grad = torch.rand_like(p)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    optimizer.step()
+    dist_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
+    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
+    assert dist_meta == meta['dist_meta']
+    assert len(model_checkpoints) == 1
+    assert len(optimizer_checkpoints) == 1
+    assert 'fc.weight' in model_checkpoints[0] and 'fc.bias' in model_checkpoints[0]
+    assert 0 in optimizer_checkpoints[0]['state'] and 1 in optimizer_checkpoints[0]['state']
+    dist_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
+    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer, dist_meta=dist_meta)
+    assert dist_meta == meta['dist_meta']
+    assert len(model_checkpoints) == 1
+    assert len(optimizer_checkpoints) == 1
+
+
+if __name__ == '__main__':
+    test_global_model()
+    test_global_model_shard()
+    test_global_optimizer()
+    test_global_optimizer_shard()
+    test_dist_model_optimizer()
diff --git a/tests/test_utils/test_checkpoint_io/test_load.py b/tests/test_utils/test_checkpoint_io/test_load.py
new file mode 100644
index 000000000..780c13dc5
--- /dev/null
+++ b/tests/test_utils/test_checkpoint_io/test_load.py
@@ -0,0 +1,188 @@
+from copy import deepcopy
+from functools import partial +from tempfile import TemporaryDirectory +from typing import Dict + +import colossalai +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from colossalai.utils.checkpoint_io.io import load, save +from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, RankRedistMeta, RedistMeta) +from torch import Tensor +from torch.nn import Module +from torch.optim import Adam, Optimizer + + +def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None: + assert set(a.keys()) == set(b.keys()) + for k, v in a.items(): + assert torch.equal(v, b[k]) + + +def check_optim_state_dict(a: dict, b: dict, ignore_param_gruops: bool = False) -> None: + assert set(a['state'].keys()) == set(b['state'].keys()) + for k, state in a['state'].items(): + b_state = b['state'][k] + for v1, v2 in zip(state.values(), b_state.values()): + if isinstance(v1, Tensor): + assert torch.equal(v1, v2) + else: + assert v1 == v2 + if not ignore_param_gruops: + assert a['param_groups'] == b['param_groups'] + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.rand_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def reset_model_optim(model: Module, optimizer: Optimizer, scalar: float = 0.0): + with torch.no_grad(): + for p in model.parameters(): + p.fill_(scalar) + for state in optimizer.state.values(): + for v in state.values(): + if isinstance(v, Tensor): + v.fill_(scalar) + + +def get_dist_metas(nprocs: int, zero: bool = False): + dp_world_size = nprocs // 2 + dist_metas = [] + for rank in range(nprocs): + if zero: + dist_metas.append({ + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + }) + else: + dist_metas.append({ + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + }) + return dist_metas + + +def get_redist_meta(nprocs: int): + dp_world_size = nprocs // 2 + rank_meta = { + 'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)}, + 'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)} + } + param_meta = { + 'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamRedistMeta(dp_world_size, 1) + } + return RedistMeta(rank_meta, [], param_meta) + + +@pytest.mark.parametrize('max_shard_size_gb', [80 / 1024**3, 0]) +def test_save_global_load_global(max_shard_size_gb: float): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, 
max_shard_size_gb=max_shard_size_gb) + new_model, new_optimizer = prepare_model_optim() + load(dir_name, new_model, new_optimizer, max_shard_size_gb=max_shard_size_gb) + check_model_state_dict(model.state_dict(), new_model.state_dict()) + check_optim_state_dict(optimizer.state_dict(), new_optimizer.state_dict()) + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + func() + + +def launch_dist(fn, world_size: int): + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + + +def save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + reset_model_optim(model, optmizer) + world_size = dist.get_world_size() + rank = dist.get_rank() + save(dir_name, model, optmizer, dist_meta=get_dist_metas(world_size, zero)[rank]) + + +def load_and_check_dist(dir_name: str): + world_size = dist.get_world_size() + model, optmizer = prepare_model_optim(shard=True) + reset_model_optim(model, optmizer) + model_state_dict = deepcopy(model.state_dict()) + optimizer_state_dict = deepcopy(optmizer.state_dict()) + reset_model_optim(model, optmizer, 1) + load(dir_name, model, optmizer, get_redist_meta(world_size), get_dist_metas(world_size)) + check_model_state_dict(model_state_dict, model.state_dict()) + check_optim_state_dict(optimizer_state_dict, optmizer.state_dict()) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_global_load_dist(): + model, optimizer = prepare_model_optim() + reset_model_optim(model, optimizer) + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 4) + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_save_dist_load_dist(): + with TemporaryDirectory() as dir_name: + # save tp + dp + fn = partial(save_dist, dir_name, False) + launch_dist(fn, 2) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + with TemporaryDirectory() as dir_name: + # save tp + zero + fn = partial(save_dist, dir_name, True) + launch_dist(fn, 4) + # load tp + dp + fn = partial(load_and_check_dist, dir_name) + launch_dist(fn, 2) + launch_dist(fn, 4) + + +if __name__ == '__main__': + test_save_global_load_global(80 / 1024**3) + test_save_global_load_global(0) + test_save_global_load_dist() + test_save_dist_load_dist() diff --git a/tests/test_utils/test_checkpoint_io/test_merge.py b/tests/test_utils/test_checkpoint_io/test_merge.py new file mode 100644 index 000000000..04e454dcb --- /dev/null +++ b/tests/test_utils/test_checkpoint_io/test_merge.py @@ -0,0 +1,127 @@ +from colossalai.utils.checkpoint_io.meta import ParamDistMeta +from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME +from colossalai.utils.checkpoint_io.io import save, merge +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port +from tempfile import TemporaryDirectory +from torch.optim import Adam +from functools import partial +import torch +import os +import pytest +import colossalai +import torch.nn as nn +import torch.distributed as dist +import torch.multiprocessing as mp + + +class DummyModel(nn.Module): + + def __init__(self) -> None: + super().__init__() + self.fc = nn.Linear(20, 1) + + +def prepare_model_optim(shard: bool = False, zero: bool = False): + model = DummyModel() + if shard: + model.fc.weight.data = 
model.fc.weight.chunk(2, 1)[dist.get_rank() % 2] + if zero: + dp_rank = dist.get_rank() // 2 + model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank] + if dp_rank != 0: + model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype) + for p in model.parameters(): + p.grad = torch.ones_like(p) + optimizer = Adam(model.parameters(), lr=1e-3) + optimizer.step() + return model, optimizer + + +def test_merge_global(): + model, optimizer = prepare_model_optim() + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + with TemporaryDirectory() as dir_name: + save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 0 + + +def run_dist(rank, world_size, port, func): + colossalai.launch(config={'parallel': { + 'tensor': { + 'mode': '1d', + 'size': 2 + } + }}, + rank=rank, + world_size=world_size, + host='localhost', + port=port, + backend='nccl') + func() + + +def run_save_dist(dir_name: str, zero: bool): + model, optmizer = prepare_model_optim(shard=True, zero=zero) + rank = dist.get_rank() + dp_world_size = dist.get_world_size() // 2 + if not zero: + dist_metas = { + 'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]), + 'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1) + } + else: + dist_metas = { + 'fc.weight': + ParamDistMeta(rank // 2, + dp_world_size, + rank % 2, + 2, + tp_shard_dims=[1], + tp_num_parts=[2], + zero_numel=10, + zero_orig_shape=[1, 10]), + 'fc.bias': + ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1]) + } + save(dir_name, model, optmizer, dist_meta=dist_metas) + + +@pytest.mark.dist +@pytest.mark.parametrize("zero", [False, True]) +@rerun_if_address_is_in_use() +def test_merge_tp_dp(zero: bool): + with TemporaryDirectory() as dir_name: + fn = partial(run_save_dist, dir_name, zero) + world_size = 4 + proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn) + mp.spawn(proc_fn, nprocs=world_size) + with TemporaryDirectory() as output_dir: + merge(dir_name, output_dir) + assert len(os.listdir(output_dir)) == 5 + global_meta = torch.load(os.path.join(output_dir, GLOBAL_META_FILE_NAME)) + assert len(global_meta['meta']) == 1 + meta = torch.load(os.path.join(output_dir, global_meta['meta'][0])) + assert meta['dist_meta'] is None + assert len(meta['params']) == 2 + assert len(meta['model']) == 1 and len(meta['optimizer']) == 1 + model_state_dict = torch.load(os.path.join(output_dir, meta['model'][0])) + assert len(model_state_dict) == 2 + assert model_state_dict['fc.weight'].size(1) == 20 + optimizer_state_dict = torch.load(os.path.join(output_dir, meta['optimizer'][0])) + assert len(optimizer_state_dict['state']) == 2 + assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict + assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 20 + assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 20 + + +if __name__ == '__main__': + test_merge_global() + test_merge_tp_dp(False) + test_merge_tp_dp(True) diff --git a/tests/test_utils/test_checkpoint_io/test_merge_param.py b/tests/test_utils/test_checkpoint_io/test_merge_param.py new file mode 100644 index 000000000..5da2ae4fe --- /dev/null +++ 
@@ -0,0 +1,101 @@
+import torch
+from colossalai.utils.checkpoint_io.meta import ParamDistMeta
+from colossalai.utils.checkpoint_io.distributed import unflatten_zero_param, gather_tp_param, merge_param
+
+
+def test_unflatten_zero_param_even() -> None:
+    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(4)]
+    orig_tensor = torch.rand(4, 4)
+    tensors = list(orig_tensor.reshape(-1).chunk(4))
+    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, unflattened_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_unflatten_zero_param_uneven() -> None:
+    dist_metas = [ParamDistMeta(i, 4, 0, 1, zero_numel=16, zero_orig_shape=[4, 4]) for i in range(1, 3)]
+    orig_tensor = torch.rand(4, 4)
+    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
+    unflattened_tensor = unflatten_zero_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, unflattened_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_gather_tp_param_1d_row() -> None:
+    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[0], tp_num_parts=[4]) for i in range(4)]
+    orig_tensor = torch.rand(4, 4)
+    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
+    gathered_tensor = gather_tp_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, gathered_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_gather_tp_param_1d_col() -> None:
+    dist_metas = [ParamDistMeta(0, 1, i, 4, tp_shard_dims=[1], tp_num_parts=[4]) for i in range(4)]
+    orig_tensor = torch.rand(4, 4)
+    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
+    gathered_tensor = gather_tp_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, gathered_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_gather_tp_param_2d() -> None:
+    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3]) for i in range(6)]
+    orig_tensor = torch.rand(4, 6)
+    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
+    gathered_tensor = gather_tp_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, gathered_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_gather_tp_param_2d_reverse() -> None:
+    dist_metas = [ParamDistMeta(0, 1, i, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2]) for i in range(6)]
+    orig_tensor = torch.rand(4, 6)
+    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
+    gathered_tensor = gather_tp_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, gathered_tensor)
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_merge_param_hybrid() -> None:
+    dist_metas = [
+        ParamDistMeta(i % 2,
+                      2,
+                      i // 2,
+                      6,
+                      tp_shard_dims=[1, 0],
+                      tp_num_parts=[3, 2],
+                      zero_numel=4,
+                      zero_orig_shape=[2, 2]) for i in range(12)
+    ]
+    orig_tensor = torch.rand(4, 6)
+    tensors = [
+        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
+        for chunk in t.contiguous().reshape(-1).split([1, 3])
+    ]
+    merged_tensor = merge_param(tensors, dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+def test_merge_param_dummy() -> None:
+    dist_metas = [ParamDistMeta(0, 1, 0, 1)]
+    orig_tensor = torch.rand(4, 6)
+    merged_tensor = merge_param([orig_tensor], dist_metas)
+    assert torch.equal(orig_tensor, merged_tensor)
+
+
+if __name__ == '__main__':
+    test_unflatten_zero_param_even()
+    test_unflatten_zero_param_uneven()
+    test_gather_tp_param_1d_row()
+    test_gather_tp_param_1d_col()
+    test_gather_tp_param_2d()
+    test_gather_tp_param_2d_reverse()
+    test_merge_param_hybrid()
+    test_merge_param_dummy()
diff --git a/tests/test_utils/test_checkpoint_io/test_redist.py b/tests/test_utils/test_checkpoint_io/test_redist.py
new file mode 100644
index 000000000..6e76f3167
--- /dev/null
+++ b/tests/test_utils/test_checkpoint_io/test_redist.py
@@ -0,0 +1,149 @@
+import os
+from functools import partial
+from tempfile import TemporaryDirectory
+
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.checkpoint_io.constant import GLOBAL_META_FILE_NAME
+from colossalai.utils.checkpoint_io.io import redist, save
+from colossalai.utils.checkpoint_io.meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta,
+                                                 RedistMeta)
+from torch.optim import Adam
+
+
+class DummyModel(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(20, 1)
+
+
+def prepare_model_optim(shard: bool = False, zero: bool = False):
+    model = DummyModel()
+    if shard:
+        model.fc.weight.data = model.fc.weight.chunk(2, 1)[dist.get_rank() % 2]
+    if zero:
+        dp_rank = dist.get_rank() // 2
+        model.fc.weight.data = model.fc.weight.reshape(-1).split([3, model.fc.weight.size(1) - 3], 0)[dp_rank]
+        if dp_rank != 0:
+            model.fc.bias.data = torch.empty(0, dtype=model.fc.bias.dtype)
+    for p in model.parameters():
+        p.grad = torch.ones_like(p)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    optimizer.step()
+    return model, optimizer
+
+
+def get_dist_metas(nprocs: int, zero: bool = False):
+    dp_world_size = nprocs // 2
+    dist_metas = []
+    for rank in range(nprocs):
+        if zero:
+            dist_metas.append({
+                'fc.weight':
+                    ParamDistMeta(rank // 2,
+                                  dp_world_size,
+                                  rank % 2,
+                                  2,
+                                  tp_shard_dims=[1],
+                                  tp_num_parts=[2],
+                                  zero_numel=10,
+                                  zero_orig_shape=[1, 10]),
+                'fc.bias':
+                    ParamDistMeta(rank // 2, dp_world_size, 0, 1, zero_numel=1, zero_orig_shape=[1])
+            })
+        else:
+            dist_metas.append({
+                'fc.weight': ParamDistMeta(rank // 2, dp_world_size, rank % 2, 2, tp_shard_dims=[1], tp_num_parts=[2]),
+                'fc.bias': ParamDistMeta(rank // 2, dp_world_size, 0, 1)
+            })
+    return dist_metas
+
+
+def get_redist_meta(nprocs: int):
+    dp_world_size = nprocs // 2
+    rank_meta = {
+        'fc.weight': {rank: RankRedistMeta(rank // 2, rank % 2, 0) for rank in range(nprocs)},
+        'fc.bias': {rank: RankRedistMeta(rank // 2, 0, 0) for rank in range(nprocs)}
+    }
+    param_meta = {
+        'fc.weight': ParamRedistMeta(dp_world_size, 2, tp_shard_dims=[1], tp_num_parts=[2]),
+        'fc.bias': ParamRedistMeta(dp_world_size, 1)
+    }
+    return RedistMeta(rank_meta, [], param_meta)
+
+
+def check_checkpoint_shape(dir_name: str):
+    global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
+    for meta_name in global_meta['meta']:
+        meta = torch.load(os.path.join(dir_name, meta_name))
+        assert meta['dist_meta'] is not None
+        assert len(meta['params']) == 2
+        assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
+        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
+        assert len(model_state_dict) == 2
+        assert model_state_dict['fc.weight'].size(1) == 10
+        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
+        assert len(optimizer_state_dict['state']) == 2
+        assert 'param_groups' in optimizer_state_dict and 'state' in optimizer_state_dict
+        assert optimizer_state_dict['state'][0]['exp_avg'].size(1) == 10
+        assert optimizer_state_dict['state'][0]['exp_avg_sq'].size(1) == 10
+
+
+def test_global_to_dist():
+    model, optimizer = prepare_model_optim()
+    with TemporaryDirectory() as dir_name:
+        save(dir_name, model, optimizer)
+        with TemporaryDirectory() as output_dir:
+            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
+            check_checkpoint_shape(output_dir)
+
+
+def run_dist(rank, world_size, port, func):
+    colossalai.launch(config={'parallel': {
+        'tensor': {
+            'mode': '1d',
+            'size': 2
+        }
+    }},
+                      rank=rank,
+                      world_size=world_size,
+                      host='localhost',
+                      port=port,
+                      backend='nccl')
+    func()
+
+
+def run_save_dist(dir_name: str, zero: bool):
+    model, optimizer = prepare_model_optim(shard=True, zero=zero)
+    rank = dist.get_rank()
+    save(dir_name, model, optimizer, dist_meta=get_dist_metas(4, zero)[rank])
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize("zero", [False, True])
+@rerun_if_address_is_in_use()
+def test_dist_to_dist(zero: bool):
+    with TemporaryDirectory() as dir_name:
+        fn = partial(run_save_dist, dir_name, zero)
+        world_size = 4
+        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
+        mp.spawn(proc_fn, nprocs=world_size)
+        with TemporaryDirectory() as output_dir:
+            redist(dir_name, output_dir, get_redist_meta(4), get_dist_metas(4))
+            if not zero:
+                assert len(os.listdir(output_dir)) == 0
+            else:
+                check_checkpoint_shape(output_dir)
+
+
+if __name__ == '__main__':
+    test_global_to_dist()
+    test_dist_to_dist(False)
+    test_dist_to_dist(True)
diff --git a/tests/test_utils/test_checkpoint_io/test_save.py b/tests/test_utils/test_checkpoint_io/test_save.py
new file mode 100644
index 000000000..5ff9d0aa2
--- /dev/null
+++ b/tests/test_utils/test_checkpoint_io/test_save.py
@@ -0,0 +1,147 @@
+import os
+from functools import partial
+from tempfile import TemporaryDirectory
+from typing import Dict
+
+import colossalai
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import torch.nn as nn
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port
+from colossalai.utils.checkpoint_io.constant import (GLOBAL_META_FILE_NAME, META_CKPT_FILE_NAME, MODEL_CKPT_FILE_NAME,
+                                                     OTHER_CKPT_FILE_NAME)
+from colossalai.utils.checkpoint_io.io import save
+from colossalai.utils.checkpoint_io.meta import ParamDistMeta
+from torch import Tensor
+from torch.optim import Adam
+
+
+def check_model_state_dict(a: Dict[str, Tensor], b: Dict[str, Tensor]) -> None:
+    assert set(a.keys()) == set(b.keys())
+    for k, v in a.items():
+        assert torch.equal(v, b[k])
+
+
+def check_optim_state_dict(a: dict, b: dict, ignore_param_groups: bool = False) -> None:
+    assert set(a['state'].keys()) == set(b['state'].keys())
+    for k, state in a['state'].items():
+        b_state = b['state'][k]
+        for v1, v2 in zip(state.values(), b_state.values()):
+            if isinstance(v1, Tensor):
+                assert torch.equal(v1, v2)
+            else:
+                assert v1 == v2
+    if not ignore_param_groups:
+        assert a['param_groups'] == b['param_groups']
+
+
+class DummyModel(nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.fc = nn.Linear(20, 1)
+
+
+def prepare_model_optim():
+    model = DummyModel()
+    for p in model.parameters():
+        p.grad = torch.ones_like(p)
+    optimizer = Adam(model.parameters(), lr=1e-3)
+    optimizer.step()
+    return model, optimizer
+
+
+def test_overwrite():
+    model = DummyModel()
+    with TemporaryDirectory() as dir_name:
+        with open(os.path.join(dir_name, MODEL_CKPT_FILE_NAME.replace('.bin', '-shard0.bin')), 'a') as f:
+            pass
+        with pytest.raises(RuntimeError, match=r'Save error: Checkpoint ".+" exists\. \(overwrite = False\)'):
+            save(dir_name, model)
+
+
+def test_save_global():
+    model, optimizer = prepare_model_optim()
+    with TemporaryDirectory() as dir_name:
+        save(dir_name, model, optimizer)
+        assert len(os.listdir(dir_name)) == 5
+        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
+        assert len(global_meta['meta']) == 1 and global_meta['meta'][0] == META_CKPT_FILE_NAME
+        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
+        assert len(meta['model']) == 1
+        assert len(meta['optimizer']) == 1
+        model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
+        check_model_state_dict(model.state_dict(), model_state_dict)
+        optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
+        check_optim_state_dict(optimizer.state_dict(), optimizer_state_dict)
+        other_state_dict = torch.load(os.path.join(dir_name, OTHER_CKPT_FILE_NAME))
+        assert len(other_state_dict) == 0
+
+
+def test_save_global_shard():
+    model, optimizer = prepare_model_optim()
+    with TemporaryDirectory() as dir_name:
+        save(dir_name, model, optimizer, max_shard_size_gb=80 / 1024**3)
+        assert len(os.listdir(dir_name)) == 7
+        meta = torch.load(os.path.join(dir_name, META_CKPT_FILE_NAME))
+        assert len(meta['model']) == 2 and len(meta['optimizer']) == 2
+        model_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['model']]
+        assert len(set(model_state_dicts[0].keys()) & set(model_state_dicts[1].keys())) == 0
+        check_model_state_dict(model.state_dict(), {**model_state_dicts[0], **model_state_dicts[1]})
+        optimizer_state_dicts = [torch.load(os.path.join(dir_name, name)) for name in meta['optimizer']]
+        assert len(set(optimizer_state_dicts[0]['state'].keys()) & set(optimizer_state_dicts[1]['state'].keys())) == 0
+        assert 'param_groups' in optimizer_state_dicts[0] and 'param_groups' not in optimizer_state_dicts[1]
+        check_optim_state_dict(
+            optimizer.state_dict(), {
+                'state': {
+                    **optimizer_state_dicts[0]['state'],
+                    **optimizer_state_dicts[1]['state']
+                },
+                'param_groups': optimizer_state_dicts[0]['param_groups']
+            })
+
+
+def run_dist(rank, world_size, port, func):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    func()
+
+
+def run_save_dist(dir_name):
+    model, optimizer = prepare_model_optim()
+    dist_metas = {
+        'fc.weight': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1),
+        'fc.bias': ParamDistMeta(dist.get_rank(), dist.get_world_size(), 0, 1)
+    }
+    save(dir_name, model, optimizer, dist_meta=dist_metas)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_save_dist():
+    with TemporaryDirectory() as dir_name:
+        fn = partial(run_save_dist, dir_name)
+        world_size = 2
+        proc_fn = partial(run_dist, world_size=world_size, port=free_port(), func=fn)
+        mp.spawn(proc_fn, nprocs=world_size)
+        assert len(os.listdir(dir_name)) == 8
+        global_meta = torch.load(os.path.join(dir_name, GLOBAL_META_FILE_NAME))
+        assert len(global_meta['meta']) == 2
+        for rank, meta_name in enumerate(global_meta['meta']):
+            meta = torch.load(os.path.join(dir_name, meta_name))
+            assert meta.get('dist_meta', None) is not None
+            assert len(meta['model']) == 1 and len(meta['optimizer']) == 1
+            model_state_dict = torch.load(os.path.join(dir_name, meta['model'][0]))
+            assert len(model_state_dict) == 2
+            optimizer_state_dict = torch.load(os.path.join(dir_name, meta['optimizer'][0]))
+            assert len(optimizer_state_dict['state']) == 2
+            assert 'param_groups' in optimizer_state_dict
+
+
+if __name__ == '__main__':
+    test_overwrite()
+    test_save_global()
+    test_save_global_shard()
+    test_save_dist()
diff --git a/tests/test_utils/test_checkpoint_io/test_unmerge_param.py b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py
new file mode 100644
index 000000000..8b83caa12
--- /dev/null
+++ b/tests/test_utils/test_checkpoint_io/test_unmerge_param.py
@@ -0,0 +1,137 @@
+import torch
+from colossalai.utils.checkpoint_io.meta import ParamRedistMeta
+from colossalai.utils.checkpoint_io.distributed import flatten_zero_param, split_tp_param, unmerge_param
+
+
+def test_flatten_zero_param_even() -> None:
+    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=0, zero_offsets=[0, 4, 8, 12])
+    orig_tensor = torch.rand(4, 4)
+    tensors = list(orig_tensor.reshape(-1).chunk(4))
+    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(flat_tensors)
+    for t, st in zip(tensors, flat_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(unmerged_tensors) == 1
+    unmerged_tensors = unmerged_tensors[0]
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert torch.equal(t, tl)
+
+
+def test_flatten_zero_param_uneven() -> None:
+    redist_meta = ParamRedistMeta(4, 1, zero_start_dp_rank=1, zero_offsets=[0, 13])
+    orig_tensor = torch.rand(4, 4)
+    tensors = list(orig_tensor.reshape(-1).split([13, 3]))
+    flat_tensors = flatten_zero_param(orig_tensor, redist_meta)
+    assert flat_tensors[0].size(0) == 0 and flat_tensors[-1].size(0) == 0
+    flat_tensors = flat_tensors[1:-1]
+    assert len(tensors) == len(flat_tensors)
+    for t, st in zip(tensors, flat_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(unmerged_tensors) == 1
+    unmerged_tensors = unmerged_tensors[0]
+    assert unmerged_tensors[0].size(0) == 0 and unmerged_tensors[-1].size(0) == 0
+    unmerged_tensors = unmerged_tensors[1:-1]
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert torch.equal(t, tl)
+
+
+def test_split_tp_param_1d_row() -> None:
+    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[0], tp_num_parts=[4])
+    orig_tensor = torch.rand(4, 4)
+    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 0)]
+    split_tensors = split_tp_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(split_tensors)
+    for t, st in zip(tensors, split_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert len(tl) == 1
+        assert torch.equal(t, tl[0])
+
+
+def test_split_tp_param_1d_col() -> None:
+    redist_meta = ParamRedistMeta(1, 4, tp_shard_dims=[1], tp_num_parts=[4])
+    orig_tensor = torch.rand(4, 4)
+    tensors = [t.contiguous() for t in orig_tensor.chunk(4, 1)]
+    split_tensors = split_tp_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(split_tensors)
+    for t, st in zip(tensors, split_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert len(tl) == 1
+        assert torch.equal(t, tl[0])
+
+
+def test_split_tp_param_2d() -> None:
+    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[0, 1], tp_num_parts=[2, 3])
+    orig_tensor = torch.rand(4, 6)
+    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
+    split_tensors = split_tp_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(split_tensors)
+    for t, st in zip(tensors, split_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert len(tl) == 1
+        assert torch.equal(t, tl[0])
+
+
+def test_split_tp_param_2d_reverse() -> None:
+    redist_meta = ParamRedistMeta(1, 6, tp_shard_dims=[1, 0], tp_num_parts=[3, 2])
+    orig_tensor = torch.rand(4, 6)
+    tensors = [t.contiguous() for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)]
+    split_tensors = split_tp_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(split_tensors)
+    for t, st in zip(tensors, split_tensors):
+        assert torch.equal(t, st)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(tensors) == len(unmerged_tensors)
+    for t, tl in zip(tensors, unmerged_tensors):
+        assert len(tl) == 1
+        assert torch.equal(t, tl[0])
+
+
+def test_unmerge_param_hybrid() -> None:
+    redist_meta = ParamRedistMeta(2,
+                                  6,
+                                  tp_shard_dims=[1, 0],
+                                  tp_num_parts=[3, 2],
+                                  zero_start_dp_rank=0,
+                                  zero_offsets=[0, 1])
+    orig_tensor = torch.rand(4, 6)
+    tensors = [
+        chunk for tl in orig_tensor.chunk(2, 0) for t in tl.chunk(3, 1)
+        for chunk in t.contiguous().reshape(-1).split([1, 3])
+    ]
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(unmerged_tensors) == 6 and len(unmerged_tensors[0]) == 2
+    for tp_rank in range(6):
+        for dp_rank in range(2):
+            assert torch.equal(tensors[tp_rank * 2 + dp_rank], unmerged_tensors[tp_rank][dp_rank])
+
+
+def test_unmerge_param_dummy() -> None:
+    redist_meta = ParamRedistMeta(1, 1)
+    orig_tensor = torch.rand(4, 6)
+    unmerged_tensors = unmerge_param(orig_tensor, redist_meta)
+    assert len(unmerged_tensors) == 1 and len(unmerged_tensors[0]) == 1
+    assert torch.equal(orig_tensor, unmerged_tensors[0][0])
+
+
+if __name__ == '__main__':
+    test_flatten_zero_param_even()
+    test_flatten_zero_param_uneven()
+    test_split_tp_param_1d_row()
+    test_split_tp_param_1d_col()
+    test_split_tp_param_2d()
+    test_split_tp_param_2d_reverse()
+    test_unmerge_param_hybrid()
+    test_unmerge_param_dummy()
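
The tests above exercise the public entry points of the new module (save, load, merge, redist). As a quick orientation, the following is a minimal single-process usage sketch based on those tests; it is not part of the patch itself, and the model, the temporary directories, and the commented dist_meta values are illustrative assumptions rather than anything this PR prescribes.

import torch
import torch.nn as nn
from tempfile import TemporaryDirectory
from torch.optim import Adam

from colossalai.utils.checkpoint_io import load, merge, save
from colossalai.utils.checkpoint_io.meta import ParamDistMeta    # used only in the commented example below

# Illustrative model/optimizer pair; any nn.Module works the same way.
model = nn.Linear(20, 1)
optimizer = Adam(model.parameters(), lr=1e-3)
for p in model.parameters():
    p.grad = torch.ones_like(p)
optimizer.step()    # populate optimizer state so it ends up in the checkpoint

with TemporaryDirectory() as ckpt_dir:
    # Global (unsharded) save: writes global_meta.bin, meta.bin, model.bin,
    # optim.bin and other.bin into ckpt_dir; pass max_shard_size_gb to split
    # the model/optimizer states across several shard files instead.
    save(ckpt_dir, model, optimizer)

    # Load back into freshly constructed objects; redist_meta/dist_metas are
    # only needed when loading into a differently distributed setup.
    new_model = nn.Linear(20, 1)
    new_optimizer = Adam(new_model.parameters(), lr=1e-3)
    load(ckpt_dir, new_model, new_optimizer)

    # merge() turns a distributed checkpoint into a global one; on an
    # already-global checkpoint it is a no-op and leaves output_dir empty.
    with TemporaryDirectory() as output_dir:
        merge(ckpt_dir, output_dir)

# When every rank saves its own shard, a dist_meta dict describes how each
# parameter is distributed, e.g. (illustrative values: a weight split into two
# tensor-parallel parts along dim 1, with a replicated bias; tp_rank is the
# calling rank's tensor-parallel rank):
#   dist_meta = {
#       'fc.weight': ParamDistMeta(0, 1, tp_rank, 2, tp_shard_dims=[1], tp_num_parts=[2]),
#       'fc.bias': ParamDistMeta(0, 1, 0, 1)
#   }
#   save(ckpt_dir, model, optimizer, dist_meta=dist_meta)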