[CheckpointIO] a uniform checkpoint I/O module (#1689)

Author: ver217
Date: 2022-11-08 15:15:13 +08:00
Committed by: GitHub
Parent: 629172b319
Commit: 99870726b1
17 changed files with 2111 additions and 0 deletions

View File: __init__.py

@@ -0,0 +1,2 @@
from .io import load, merge, redist, save
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
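The package __init__ above collects the user-facing entry points into a single namespace. A one-line import sketch (the top-level package name checkpoint_io is an assumption, it is not visible in this diff):

from checkpoint_io import ParamDistMeta, RedistMeta, load, merge, redist, save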

View File: backend.py

@@ -0,0 +1,74 @@
import shutil
import tempfile
from abc import ABC, abstractmethod
from typing import Dict, List, Type

from .reader import CheckpointReader, DiskCheckpointReader
from .writer import CheckpointWriter, DiskCheckpointWriter

_backends: Dict[str, Type['CheckpointIOBackend']] = {}

def register(name: str):
    assert name not in _backends, f'"{name}" is registered'

    def wrapper(cls):
        _backends[name] = cls
        return cls

    return wrapper

def get_backend(name: str) -> 'CheckpointIOBackend':
    assert name in _backends, f'Unsupported backend "{name}"'
    return _backends[name]()

class CheckpointIOBackend(ABC):

    def __init__(self) -> None:
        super().__init__()
        self.temps: List[str] = []

    @abstractmethod
    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        pass

    @abstractmethod
    def get_reader(self, base_name: str) -> CheckpointReader:
        pass

    @abstractmethod
    def get_temp(self, base_name: str) -> str:
        pass

    @abstractmethod
    def clean_temp(self) -> None:
        pass

@register('disk')
class CheckpointDiskIO(CheckpointIOBackend):

    def get_writer(self,
                   base_name: str,
                   overwrite: bool = False,
                   rank: int = 0,
                   world_size: int = 1) -> CheckpointWriter:
        return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)

    def get_reader(self, base_name: str) -> CheckpointReader:
        return DiskCheckpointReader(base_name)

    def get_temp(self, base_name: str) -> str:
        temp_dir_name = tempfile.mkdtemp(dir=base_name)
        self.temps.append(temp_dir_name)
        return temp_dir_name

    def clean_temp(self) -> None:
        for temp_dir_name in self.temps:
            shutil.rmtree(temp_dir_name)
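The registry above maps a string name to a backend class, and get_backend instantiates it on demand. A minimal usage sketch of the disk backend (the checkpoint_io import path and the ckpt_dir directory are assumptions for illustration, not part of this commit):

import torch
from checkpoint_io.backend import get_backend    # import path is an assumption

io_backend = get_backend('disk')                  # instantiates the class registered under 'disk'
writer = io_backend.get_writer('ckpt_dir', overwrite=True, rank=0, world_size=1)
writer.save_model({'weight': torch.zeros(2, 2)})
writer.save_meta({'dist_meta': None, 'params': ['weight']})
reader = io_backend.get_reader('ckpt_dir')        # reads the global_meta.bin written by the writer
print(next(reader.load_model(rank=0)).keys())
temp_dir = io_backend.get_temp('ckpt_dir')        # scratch dir, tracked in io_backend.temps
io_backend.clean_temp()                           # removes every tracked temp directory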

View File: constant.py

@@ -0,0 +1,9 @@
import re
GLOBAL_META_FILE_NAME = 'global_meta.bin'
MODEL_CKPT_FILE_NAME = 'model.bin'
OPTIM_CKPT_FILE_NAME = 'optim.bin'
META_CKPT_FILE_NAME = 'meta.bin'
OTHER_CKPT_FILE_NAME = 'other.bin'
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
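These base names are combined with rank and shard suffixes by the DiskCheckpointWriter later in this commit; a tiny sketch of the naming scheme (the checkpoint_io import path is an assumption):

from checkpoint_io.constant import CKPT_PAT, MODEL_CKPT_FILE_NAME    # import path is an assumption

# 'model.bin' -> 'model-rank1.bin' -> 'model-rank1-shard0.bin'
name = MODEL_CKPT_FILE_NAME.replace('.bin', '-rank1.bin').replace('.bin', '-shard0.bin')
assert name == 'model-rank1-shard0.bin'
assert CKPT_PAT.match(name) is not None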

View File: convertor.py

@@ -0,0 +1,227 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional

from torch import Tensor

from .distributed import merge_param, unmerge_param
from .meta import ParamDistMeta, RedistMeta
from .utils import ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none

class CheckpointConvertor(ABC):

    @abstractmethod
    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        pass

    @abstractmethod
    def complete(self) -> None:
        pass

class ModelCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int]) -> None:
        super().__init__()
        self.param_count = param_count
        self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)

    @abstractmethod
    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            for k, tensor in state_dict.items():
                self.buffer[k][rank] = tensor
        converted_keys = set()
        for k, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[k]:
                tensors = []
                dist_metas = []
                for rank, tensor in rank_dict.items():
                    tensors.append(tensor)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][k])
                self.convert_tensors(k, tensors, dist_metas)
                converted_keys.add(k)
        for k in converted_keys:
            del self.buffer[k]

    def complete(self) -> None:
        assert len(self.buffer) == 0

class ModelCheckpointMerger(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
        super().__init__(param_count)
        self.sharder = ModelCheckpointSharder(max_shard_size)
        self.save_fn = save_fn

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(tensors)
        tensor = merge_param(tensors, dist_metas)
        shard = self.sharder.append(key, tensor)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())

class ModelCheckpointRedistor(ModelCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count)
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        nprocs = len(save_fns)
        self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
        if len(dist_metas) == 0:
            # already global
            tensor = tensors[0]
        else:
            assert len(dist_metas) == len(tensors)
            tensor = merge_param(tensors, dist_metas)
        for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
            for dp_rank, t in enumerate(tensor_list):
                for rank in self.rank_map[key][tp_rank][dp_rank]:
                    shard = self.sharders[rank].append(key, t)
                    run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())

class OptimizerCheckpointConvertor(CheckpointConvertor):

    def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
                 paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__()
        self.param_count = param_count
        self.param_to_os = param_to_os
        self.paired_os = paired_os
        self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
        self.os_to_param = {v: k for k, v in param_to_os.items()}

    @abstractmethod
    def setup(self, param_groups: dict) -> None:
        pass

    @abstractmethod
    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        pass

    def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
        for rank, state_dict in shard_dict.items():
            self.setup(state_dict['param_groups'])
            for idx, state in state_dict['state'].items():
                self.buffer[idx][rank] = state
        converted_indices = set()
        for idx, rank_dict in self.buffer.items():
            if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
                states = []
                dist_metas = []
                for rank, state in rank_dict.items():
                    states.append(state)
                    if dist_meta_list[rank] is not None:
                        dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
                self.convert_states(idx, states, dist_metas)
                converted_indices.add(idx)
        for idx in converted_indices:
            del self.buffer[idx]

    def complete(self) -> None:
        assert len(self.buffer) == 0

class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fn = save_fn
        self.sharder = None

    def setup(self, param_groups: dict) -> None:
        if self.sharder is None:
            self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        assert len(dist_metas) == len(states)
        new_state = {}
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
            else:
                new_state[state_key] = state_tensor
        shard = self.sharder.append(idx, new_state)
        run_if_not_none(self.save_fn, shard)

    def complete(self) -> None:
        super().complete()
        run_if_not_none(self.save_fn, self.sharder.complete())

class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):

    def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
                 param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
                 redist_meta: RedistMeta) -> None:
        super().__init__(param_count, param_to_os, paired_os)
        self.max_shard_size = max_shard_size
        self.save_fns = save_fns
        self.redist_meta = redist_meta
        self.sharders: List[OptimizerCheckpointSharder] = []
        self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
        for k, rank_meta in redist_meta.rank_meta.items():
            for rank, rank_info in rank_meta.items():
                self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)

    def setup(self, param_groups: dict) -> None:
        if len(self.sharders) == 0:
            nprocs = len(self.save_fns)
            for _ in range(nprocs):
                self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))

    def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
        need_merge: bool = True
        if len(dist_metas) == 0:
            need_merge = False
        else:
            assert len(dist_metas) == len(states)
        new_states = [{} for _ in range(len(self.save_fns))]
        for state_key, state_tensor in states[0].items():
            if self.paired_os[idx][state_key]:
                if need_merge:
                    tensor = merge_param([state[state_key] for state in states], dist_metas)
                else:
                    tensor = state_tensor
                for tp_rank, tensor_list in enumerate(
                        unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
                    for dp_rank, t in enumerate(tensor_list):
                        for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
                            new_states[rank][state_key] = t
            else:
                for new_state in new_states:
                    new_state[state_key] = state_tensor
        for rank, new_state in enumerate(new_states):
            shard = self.sharders[rank].append(idx, new_state)
            run_if_not_none(self.save_fns[rank], shard)

    def complete(self) -> None:
        super().complete()
        for rank, save_fn in enumerate(self.save_fns):
            run_if_not_none(save_fn, self.sharders[rank].complete())
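A toy sketch of the merge path above: two ZeRO (dp) shards of a single 2x2 parameter are fed to a ModelCheckpointMerger, which reassembles the full tensor and hands it to save_fn. The checkpoint_io import path, the 'fc.weight' name, and the tensor values are assumptions for illustration:

import torch
from checkpoint_io.convertor import ModelCheckpointMerger    # import path is an assumption
from checkpoint_io.meta import ParamDistMeta

# dp rank r holds half of the flattened 2x2 weight (ZeRO), tensor parallelism disabled.
dist_meta_list = [
    {'fc.weight': ParamDistMeta(dp_rank=r, dp_world_size=2, tp_rank=0, tp_world_size=1,
                                zero_numel=4, zero_orig_shape=[2, 2])} for r in range(2)
]
shards = []
merger = ModelCheckpointMerger(max_shard_size=0, save_fn=shards.append, param_count={'fc.weight': 2})
merger.append({0: {'fc.weight': torch.tensor([1., 2.])},
               1: {'fc.weight': torch.tensor([3., 4.])}}, dist_meta_list)
merger.complete()
assert torch.equal(shards[0]['fc.weight'], torch.tensor([[1., 2.], [3., 4.]]))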

View File: distributed.py

@@ -0,0 +1,127 @@
import torch
from numpy import prod
from torch import Tensor
from typing import List, Optional, Tuple
from collections import defaultdict

from .meta import ParamDistMeta, ParamRedistMeta

def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
    if not dist_metas[0].used_zero:
        # tensors are replicate
        return tensors[0]
    numel = dist_metas[0].zero_numel
    orig_shape = dist_metas[0].zero_orig_shape
    tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
    assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.'
    return torch.cat(tensors).reshape(orig_shape)

def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    for dist_meta in dist_metas[1:]:
        assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
    for t in tensors[1:]:
        assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
    if not dist_metas[0].used_tp:
        # tensors are replicate
        return tensors[0]
    total_parts = prod(dist_meta.tp_num_parts)
    assert dist_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.'
    shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True)
    for dim, num_parts in shard_info:
        buffer = []
        for start in range(0, len(tensors), num_parts):
            buffer.append(torch.cat(tensors[start:start + num_parts], dim))
        tensors = buffer
    assert len(tensors) == 1
    return tensors[0]

def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
    assert len(dist_metas) > 0
    # check world size
    for dist_meta in dist_metas[1:]:
        assert dist_meta.dp_world_size == dist_metas[0].dp_world_size, 'Expect all dist meta have the same dp_world_size'
        assert dist_meta.tp_world_size == dist_metas[0].tp_world_size, 'Expect all dist meta have the same tp_world_size'

def deduplicate_params(tensors: List[Tensor],
                       dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
    unique_dist_meta = []
    unique_idx = []
    for i, dist_meta in enumerate(dist_metas):
        if dist_meta not in unique_dist_meta:
            unique_dist_meta.append(dist_meta)
            unique_idx.append(i)
    return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]

def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
    assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
    # validate parallel info
    validate_parallel_info(dist_metas)
    tensors, dist_metas = deduplicate_params(tensors, dist_metas)
    unflattened_tensors = []
    # group zero params by tp rank
    tensor_dict = defaultdict(list)
    dist_meta_dict = defaultdict(list)
    for t, dist_meta in zip(tensors, dist_metas):
        tensor_dict[dist_meta.tp_rank].append(t)
        dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
    assert len(tensor_dict) == dist_metas[0].tp_world_size, \
        f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
    for tp_rank in tensor_dict.keys():
        unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
    return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])

def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_tp:
        assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
        return [tensor]
    total_parts = prod(redist_meta.tp_num_parts)
    assert redist_meta.tp_world_size == total_parts, \
        f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
    shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
    tensors = [tensor]
    for dim, num_parts in shard_info:
        buffer = []
        for t in tensors:
            assert t.size(dim) % num_parts == 0, \
                f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
            chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)]
            buffer.extend(chunks)
        tensors = buffer
    assert len(tensors) == redist_meta.tp_world_size
    return tensors

def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
    if not redist_meta.used_zero:
        return [tensor] * redist_meta.dp_world_size
    tensors: List[Optional[Tensor]] = [
        torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
    ]
    offsets = redist_meta.zero_offsets + [tensor.numel()]
    for i, offset in enumerate(offsets[:-1]):
        end = offsets[i + 1]
        tensors.append(tensor.view(-1)[offset:end])
    if len(tensors) < redist_meta.dp_world_size:
        tensors.extend([
            torch.empty(0, dtype=tensor.dtype, device=tensor.device)
            for _ in range(redist_meta.dp_world_size - len(tensors))
        ])
    assert len(tensors) == redist_meta.dp_world_size
    return tensors

def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
    tensors = split_tp_param(tensor, redist_meta)
    tensors = [flatten_zero_param(t, redist_meta) for t in tensors]
    return tensors
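A small round trip through the helpers above: a 4x4 global tensor is split into two tensor-parallel shards with unmerge_param, then reassembled with merge_param. The checkpoint_io import path and all meta values are assumptions for illustration:

import torch
from checkpoint_io.distributed import merge_param, unmerge_param    # import path is an assumption
from checkpoint_io.meta import ParamDistMeta, ParamRedistMeta

full = torch.arange(16.).reshape(4, 4)
# Re-shard the global tensor for 2 tp ranks (split along dim 0), no ZeRO flattening.
redist_meta = ParamRedistMeta(dp_world_size=1, tp_world_size=2, tp_shard_dims=[0], tp_num_parts=[2])
shards = unmerge_param(full, redist_meta)    # indexed as shards[tp_rank][dp_rank]
# Merge the two tp shards back into the global tensor.
dist_metas = [ParamDistMeta(dp_rank=0, dp_world_size=1, tp_rank=r, tp_world_size=2,
                            tp_shard_dims=[0], tp_num_parts=[2]) for r in range(2)]
restored = merge_param([shards[0][0], shards[1][0]], dist_metas)
assert torch.equal(restored, full)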

View File: io.py

@@ -0,0 +1,170 @@
import warnings
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple

import torch.distributed as dist
from torch.nn import Module
from torch.optim import Optimizer

from .backend import get_backend
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
                        OptimizerCheckpointRedistor)
from .meta import ParamDistMeta, RedistMeta
from .utils import build_checkpoints, optimizer_load_state_dict

def save(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         param_to_os: Optional[Dict[str, int]] = None,
         dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
         max_shard_size_gb: float = 0.0,
         overwrite: bool = False,
         backend: str = 'disk',
         **kwargs: Any) -> None:
    io_backend = get_backend(backend)
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    if world_size == 1:
        # global doesn't need dist_meta
        dist_meta = None
    else:
        assert dist_meta is not None
    max_shard_size = int(max_shard_size_gb * 1024**3)
    model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
                                                                                  param_to_os, dist_meta)
    writer = io_backend.get_writer(path, overwrite, rank, world_size)
    writer.save_others(kwargs)
    for model_checkpoint in model_checkpoints:
        writer.save_model(model_checkpoint)
    for optimizer_checkpoint in optimizer_checkpoints:
        writer.save_optimizer(optimizer_checkpoint)
    writer.save_meta(meta_checkpoint)

def merge(path: str,
          output_path: str,
          max_shard_size_gb: float = 0.0,
          overwrite: bool = False,
          backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    reader = io_backend.get_reader(path)
    if len(reader.meta_list) == 1:
        # already global
        warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
        return False
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    writer = io_backend.get_writer(output_path, overwrite=overwrite)
    writer.save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
                    dist_meta_list)
    _convert_shards(
        OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
        reader.load_optimizers(), dist_meta_list)
    meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
    if param_to_os is not None:
        meta_checkpoint['param_to_os'] = param_to_os
        meta_checkpoint['paired_os'] = paired_os
    writer.save_meta(meta_checkpoint)
    return True

def redist(path: str,
           output_path: str,
           redist_meta: RedistMeta,
           dist_metas: List[Dict[str, ParamDistMeta]],
           max_shard_size_gb: float = 0.0,
           overwrite: bool = False,
           backend: str = 'disk') -> bool:
    io_backend = get_backend(backend)
    if dist.is_initialized() and dist.get_rank() != 0:
        return False
    nprocs = len(dist_metas)
    reader = io_backend.get_reader(path)
    dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
    do_redist: bool = False
    if len(dist_meta_list) == nprocs:
        for a, b in zip(dist_metas, dist_meta_list):
            if a != b:
                do_redist = True
                break
    else:
        do_redist = True
    if not do_redist:
        warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
        return False
    writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
    writers[0].save_others(reader.load_others())
    max_shard_size = int(max_shard_size_gb * 1024**3)
    _convert_shards(
        ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
        reader.load_models(), dist_meta_list)
    _convert_shards(
        OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
                                    param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
    for writer, dist_meta in zip(writers, dist_metas):
        meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
        if param_to_os is not None:
            meta_checkpoint['param_to_os'] = param_to_os
            meta_checkpoint['paired_os'] = paired_os
        writer.save_meta(meta_checkpoint)
    return True

def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
                    dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
    for shard_dict in shard_generator:
        convertor.append(shard_dict, dist_meta_list)
    convertor.complete()

def load(path: str,
         model: Module,
         optimizer: Optional[Optimizer] = None,
         redist_meta: Optional[RedistMeta] = None,
         dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
         max_shard_size_gb: float = 0.0,
         backend: str = 'disk') -> dict:
    is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
    rank: int = dist.get_rank() if dist.is_initialized() else 0
    is_main_process: bool = rank == 0
    # validate args
    if redist_meta is None or dist_metas is None:
        assert is_global
    io_backend = get_backend(backend)
    read_path: str = path
    if is_main_process:
        # pre-process checkpoints
        temp_path = io_backend.get_temp(path)
        if is_global:
            wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
        else:
            wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
        if wrote:
            read_path = temp_path
    if not is_global:
        bcast_list = [read_path] if is_main_process else [None]
        dist.broadcast_object_list(bcast_list)
        read_path = bcast_list[0]
    reader = io_backend.get_reader(read_path)
    # load model
    for shard in reader.load_model(rank):
        model.load_state_dict(shard, strict=False)
    if optimizer is not None:
        for shard in reader.load_optimizer(rank):
            # optimizer.load_state_dict(shard)
            optimizer_load_state_dict(optimizer, shard)
    others_dict = reader.load_others()
    if not is_global:
        dist.barrier()
    # clean up temp
    if is_main_process:
        io_backend.clean_temp()
    return others_dict
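A minimal single-process sketch of the save/load entry points above. The checkpoint_io import path, the 'ckpt' directory, and the toy model are assumptions for illustration; extra keyword arguments to save end up in other.bin and come back as the return value of load. On a single process, load first tries merge, which only warns that the checkpoint is already global and then reads the original files:

import torch
from torch import nn
from torch.optim import Adam
from checkpoint_io import save, load    # import path is an assumption

model = nn.Linear(4, 4)
optimizer = Adam(model.parameters(), lr=1e-3)

# Single-process save: dist_meta may be omitted; 'epoch' is stored in other.bin.
save('ckpt', model, optimizer, max_shard_size_gb=0.5, overwrite=True, epoch=3)

# Reload into fresh instances; the return value is whatever was passed via **kwargs.
new_model = nn.Linear(4, 4)
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
others = load('ckpt', new_model, new_optimizer)
assert others['epoch'] == 3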

View File: meta.py

@@ -0,0 +1,81 @@
from dataclasses import dataclass
from typing import List, Optional, Set, Dict

@dataclass
class ParamDistMeta:
    # parallel info
    dp_rank: int
    dp_world_size: int
    tp_rank: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_numel: Optional[int] = None
    zero_orig_shape: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_numel is not None and self.zero_orig_shape is not None

    @property
    def parallel_meta(self) -> tuple:
        return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size

    @property
    def tp_meta(self) -> tuple:
        return self.tp_shard_dims, self.tp_num_parts

    @property
    def zero_meta(self) -> tuple:
        return self.zero_numel, self.zero_orig_shape

    @staticmethod
    def from_dict(d: dict) -> 'ParamDistMeta':
        return ParamDistMeta(**d)

@dataclass
class ParamRedistMeta:
    # parallel info
    dp_world_size: int
    tp_world_size: int
    # tp info
    tp_shard_dims: Optional[List[int]] = None
    tp_num_parts: Optional[List[int]] = None
    # zero info
    zero_start_dp_rank: Optional[int] = None
    zero_offsets: Optional[List[int]] = None

    @property
    def used_tp(self) -> bool:
        return self.tp_shard_dims is not None and self.tp_num_parts is not None

    @property
    def used_zero(self) -> bool:
        return self.zero_start_dp_rank is not None and self.zero_offsets is not None

@dataclass
class RankRedistMeta:
    dp_rank: int
    tp_rank: int
    pp_rank: int

@dataclass
class PipelineRedistMeta:
    params: Set[str]

@dataclass
class RedistMeta:
    rank_meta: Dict[str, Dict[int, RankRedistMeta]]
    pipeline_meta: List[PipelineRedistMeta]
    param_meta: Dict[str, ParamRedistMeta]
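A sketch of how the dataclasses above describe one parameter that is split two ways by tensor parallelism and flattened across two ZeRO dp ranks; all values are made up for illustration and the import path is an assumption:

from checkpoint_io.meta import ParamDistMeta, ParamRedistMeta    # import path is an assumption

# Where one shard currently lives: dp rank 1 of 2, tp rank 0 of 2, split along dim 0,
# holding 8 of the elements of an original 2x4 tensor after ZeRO flattening.
dist_meta = ParamDistMeta(dp_rank=1, dp_world_size=2, tp_rank=0, tp_world_size=2,
                          tp_shard_dims=[0], tp_num_parts=[2],
                          zero_numel=8, zero_orig_shape=[2, 4])
assert dist_meta.used_tp and dist_meta.used_zero
assert dist_meta.tp_meta == ([0], [2])

# The matching re-distribution target: ZeRO offsets mark where each dp rank's slice starts.
redist_meta = ParamRedistMeta(dp_world_size=2, tp_world_size=2,
                              tp_shard_dims=[0], tp_num_parts=[2],
                              zero_start_dp_rank=0, zero_offsets=[0, 4])
assert redist_meta.used_zero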

View File: reader.py

@@ -0,0 +1,131 @@
import os
from abc import ABC, abstractmethod
from collections import Counter
from typing import Dict, Generator, List, Optional, Tuple

import torch

from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
from .meta import ParamDistMeta
from .utils import is_duplicated_list

class CheckpointReader(ABC):

    def __init__(self, base_name: str) -> None:
        super().__init__()
        self.base_name = base_name
        self.meta_list = []

    @abstractmethod
    def read(self, name: str) -> dict:
        pass

    @abstractmethod
    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        pass

    @abstractmethod
    def load_model(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        pass

    @abstractmethod
    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        pass

    @abstractmethod
    def load_others(self) -> dict:
        pass

class DiskCheckpointReader(CheckpointReader):

    def __init__(self, base_name: str) -> None:
        super().__init__(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        global_meta = self.read(GLOBAL_META_FILE_NAME)
        for meta_file_name in global_meta['meta']:
            meta = self.read(meta_file_name)
            if meta.get('dist_meta', None) is None:
                # only global checkpoint can have empty dist_meta
                assert len(global_meta['meta']) == 1
            self.meta_list.append(meta)

    def read(self, name: str) -> dict:
        return torch.load(os.path.join(self.base_name, name))

    def load_meta(
            self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
        meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os', None),
                       meta.get('paired_os', None)) for meta in self.meta_list]
        dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
        # reduce param_count
        param_count = Counter(p for params in params_list for p in params)
        # validate param_to_os
        assert is_duplicated_list(param_to_os_list)
        assert is_duplicated_list(paired_os_list)
        return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]

    def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
        meta = self.meta_list[rank]
        checkpoint_names = meta.get(shard_type, [])
        for name in checkpoint_names:
            yield self.read(name)

    def load_model(self, rank: int) -> Generator[dict, None, None]:
        return self._load_shard('model', rank)

    def load_models(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                model_checkpoint_names = meta.get('model', [])
                if indices[i] < len(model_checkpoint_names):
                    shards[i] = self.read(model_checkpoint_names[indices[i]])
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
        param_groups = None
        for shard in self._load_shard('optimizer', rank):
            if param_groups is None:
                param_groups = shard['param_groups']
            else:
                shard['param_groups'] = param_groups
            yield shard

    def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
        indices = [0] * len(self.meta_list)
        param_groups = []
        while True:
            shards = {}
            for i, meta in enumerate(self.meta_list):
                optimizer_checkpoint_names = meta.get('optimizer', [])
                if indices[i] < len(optimizer_checkpoint_names):
                    shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
                    if indices[i] == 0:
                        param_groups.append(shards[i]['param_groups'])
                    else:
                        shards[i]['param_groups'] = param_groups[i]
                    indices[i] += 1
            if len(shards) > 0:
                yield shards
            else:
                break

    def load_others(self) -> dict:
        return self.read(OTHER_CKPT_FILE_NAME)
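A read-side sketch, assuming 'ckpt' is a directory produced by a DiskCheckpointWriter (for example, the one written in the earlier save sketch); the import path is an assumption:

from checkpoint_io.reader import DiskCheckpointReader    # import path is an assumption

reader = DiskCheckpointReader('ckpt')
dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
print(param_count)                         # Counter: how many ranks own each parameter
for shard in reader.load_model(rank=0):    # model shards saved by rank 0, one dict per file
    print(list(shard.keys()))
for shard_dict in reader.load_models():    # {rank: shard}, advancing one shard per rank per step
    print(list(shard_dict.keys()))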

View File: utils.py

@@ -0,0 +1,223 @@
import warnings
from copy import deepcopy
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple

from torch import Tensor
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from .meta import ParamDistMeta

def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
    if arg is not None:
        return fn(arg)

def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
    # ensure all params in optimizer are in model state dict
    params_set = set(id(p) for p in model.parameters())
    for group in optimizer.param_groups:
        for p in group['params']:
            assert id(p) in params_set
    param_mappings = {}
    start_index = 0

    def get_group_mapping(group):
        nonlocal start_index
        param_mappings.update(
            {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
        start_index += len(group['params'])

    for g in optimizer.param_groups:
        get_group_mapping(g)
    return {k: param_mappings[id(p)] for k, p in model.named_parameters()}

def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
    size = 0
    for v in state.values():
        if isinstance(v, Tensor):
            size += v.numel() * v.element_size()
    return size

class ModelCheckpointSharder:

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, Tensor] = {}
        self.buffer_size: int = 0

    def append(self, key: str, tensor: Tensor) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {}
            self.buffer_size = 0
        self.buffer[key] = tensor
        self.buffer_size += tensor.numel() * tensor.element_size()
        return retval

    def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
        shards = []
        for key, tensor in state_dict.items():
            shard = self.append(key, tensor)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer) > 0 else None

class OptimizerCheckpointSharder:

    def __init__(self, max_shard_size: int, param_groups: dict) -> None:
        self.max_shard_size = max_shard_size
        self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
        self.buffer_size: int = 0
        self.returned_first: bool = False

    def append(self, key: int, state: dict) -> Optional[dict]:
        retval = None
        if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
            retval = self.buffer
            self.buffer = {'state': {}}
            self.buffer_size = 0
        self.buffer['state'][key] = state
        self.buffer_size += compute_optimizer_state_size(state)
        return retval

    def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
        shards = []
        for key, state in state_dict['state'].items():
            shard = self.append(key, state)
            run_if_not_none(shards.append, shard)
        return shards

    def complete(self) -> Optional[dict]:
        return self.buffer if len(self.buffer['state']) > 0 else None

def shard_checkpoint(max_shard_size: int,
                     model_state_dict: Dict[str, Tensor],
                     optimizer_state_dict: Optional[dict] = None,
                     param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
    has_optimizer: bool = False
    if optimizer_state_dict is not None:
        assert param_to_os is not None
        os_to_param = {v: k for k, v in param_to_os.items()}
        for os_key in optimizer_state_dict['state'].keys():
            assert os_key in os_to_param
            assert os_to_param[os_key] in model_state_dict
        has_optimizer = True
    model_sharder = ModelCheckpointSharder(max_shard_size)
    model_shards = model_sharder.extend(model_state_dict)
    run_if_not_none(model_shards.append, model_sharder.complete())
    if not has_optimizer:
        return model_shards, []
    optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
    optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
    run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
    return model_shards, optimizer_shards

def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
    os_to_param = {v: k for k, v in param_to_os.items()}
    paired_os = {}
    for idx, state in optimizer_state_dict['state'].items():
        paired_os[idx] = {}
        p = model_state_dict[os_to_param[idx]]
        for k, v in state.items():
            if isinstance(v, Tensor) and v.shape == p.shape:
                paired_os[idx][k] = True
            else:
                paired_os[idx][k] = False
    return paired_os

def build_checkpoints(max_size: int,
                      model: Module,
                      optimizer: Optional[Optimizer] = None,
                      param_to_os: Optional[Dict[str, int]] = None,
                      dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
                      eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
    save_global = dist_meta is None
    model_state_dict = model.state_dict()
    optimizer_state_dict = optimizer.state_dict() if optimizer else None
    meta = {'dist_meta': dist_meta}
    if optimizer:
        param_to_os = param_to_os or get_param_to_os(model, optimizer)
        paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
        meta['param_to_os'] = param_to_os
        meta['paired_os'] = paired_os
    if not save_global and eliminate_replica:
        # filter dp replicated params
        model_state_dict = {
            k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
        }
        if optimizer:
            optimizer_state_dict['state'] = {
                param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
                for k in model_state_dict.keys()
                if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
            }
    meta['params'] = list(model_state_dict.keys())
    if len(model_state_dict) == 0:
        warnings.warn('model state dict is empty, checkpoint is not saved')
        return [], [], meta
    model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
                                                                param_to_os)
    return model_checkpoints, optimizer_checkpoints, meta

def is_duplicated_list(list_: List[Any]) -> bool:
    if len(list_) == 0:
        return True
    elem = list_[0]
    for x in list_[1:]:
        if x != elem:
            return False
    return True

def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
    for k, v in src_state.items():
        if k in dest_state:
            old_v = dest_state[k]
            if isinstance(old_v, Tensor):
                old_v.copy_(v)
        else:
            dest_state[k] = v

def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
    assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
    state_dict = deepcopy(state_dict)
    groups = optimizer.param_groups
    saved_groups = state_dict['param_groups']
    idx_to_p: Dict[str, Parameter] = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
                                       chain.from_iterable((g['params'] for g in groups)))
    }
    missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
    unexpected_keys = []
    error_msgs = []
    for idx, state in state_dict['state'].items():
        if idx in idx_to_p:
            old_state = optimizer.state[idx_to_p[idx]]
            copy_optimizer_state(state, old_state)
        else:
            unexpected_keys.append(idx)
    if strict:
        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
                                                                                     "\n\t".join(error_msgs)))

View File: writer.py

@@ -0,0 +1,98 @@
from abc import ABC, abstractmethod
from typing import Optional

from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME

import torch
import os

class CheckpointWriter(ABC):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__()
        self.base_name = base_name
        self.overwrite = overwrite
        self.rank = rank
        self.world_size = world_size
        self.is_distributed = world_size > 1
        self.is_main_process = rank == 0

    @abstractmethod
    def write(self, name: str, state_dict: dict) -> None:
        pass

    @abstractmethod
    def save_model(self, model_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_meta(self, meta_checkpoint: dict) -> None:
        pass

    @abstractmethod
    def save_others(self, kwargs: dict) -> None:
        pass

class DiskCheckpointWriter(CheckpointWriter):

    def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
        super().__init__(base_name, overwrite, rank, world_size)
        if not os.path.exists(base_name):
            os.makedirs(base_name)
        assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
        self.model_checkpoint_names = []
        self.optimizer_checkpoint_names = []
        self.is_meta_saved: bool = False
        self._save_global_meta()

    def write(self, name: str, state_dict: dict) -> None:
        path = os.path.join(self.base_name, name)
        if os.path.exists(path) and not self.overwrite:
            raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
        torch.save(state_dict, path)

    def _save_global_meta(self) -> None:
        if self.is_main_process:
            global_meta = {'meta': []}
            if self.is_distributed:
                for i in range(self.world_size):
                    global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
            else:
                global_meta['meta'].append(META_CKPT_FILE_NAME)
            self.write(GLOBAL_META_FILE_NAME, global_meta)

    def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
        checkpoint_name = base_name
        if self.is_distributed:
            checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
        if shard_idx is not None:
            checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
        return checkpoint_name

    def save_model(self, model_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save model after saving meta'
        name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
        self.write(name, model_checkpoint)
        self.model_checkpoint_names.append(name)

    def save_optimizer(self, optimizer_checkpoint: dict) -> None:
        assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
        name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
        self.write(name, optimizer_checkpoint)
        self.optimizer_checkpoint_names.append(name)

    def save_meta(self, meta_checkpoint: dict) -> None:
        if len(self.model_checkpoint_names) > 0:
            meta_checkpoint['model'] = self.model_checkpoint_names
        if len(self.optimizer_checkpoint_names) > 0:
            meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
        self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
        self.is_meta_saved = True

    def save_others(self, kwargs: dict) -> None:
        if self.is_main_process:
            self.write(OTHER_CKPT_FILE_NAME, kwargs)
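A write-side sketch showing how the rank and shard suffixes are applied. The rank and world_size are hard-coded here purely for illustration (they normally come from the launcher), and the checkpoint_io import path and 'ckpt_dist' directory are assumptions:

import torch
from checkpoint_io.writer import DiskCheckpointWriter    # import path is an assumption

# Pretend to be rank 1 of 2: file names get '-rank1' and '-shard<i>' suffixes.
writer = DiskCheckpointWriter('ckpt_dist', overwrite=True, rank=1, world_size=2)
writer.save_model({'fc.weight': torch.zeros(4, 4)})    # written as model-rank1-shard0.bin
writer.save_model({'fc.bias': torch.zeros(4)})         # written as model-rank1-shard1.bin
writer.save_meta({'dist_meta': {}, 'params': ['fc.weight', 'fc.bias']})    # a real run stores per-param ParamDistMeta here
print(writer.model_checkpoint_names)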