mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
[CheckpointIO] a uniform checkpoint I/O module (#1689)
This commit is contained in:
2
colossalai/utils/checkpoint_io/__init__.py
Normal file
2
colossalai/utils/checkpoint_io/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .io import load, merge, redist, save
|
||||
from .meta import (ParamDistMeta, ParamRedistMeta, PipelineRedistMeta, RankRedistMeta, RedistMeta)
|
74
colossalai/utils/checkpoint_io/backend.py
Normal file
74
colossalai/utils/checkpoint_io/backend.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import shutil
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Type
|
||||
|
||||
from .reader import CheckpointReader, DiskCheckpointReader
|
||||
from .writer import CheckpointWriter, DiskCheckpointWriter
|
||||
|
||||
_backends: Dict[str, Type['CheckpointIOBackend']] = {}
|
||||
|
||||
|
||||
def register(name: str):
|
||||
assert name not in _backends, f'"{name}" is registered'
|
||||
|
||||
def wrapper(cls):
|
||||
_backends[name] = cls
|
||||
return cls
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_backend(name: str) -> 'CheckpointIOBackend':
|
||||
assert name in _backends, f'Unsupported backend "{name}"'
|
||||
return _backends[name]()
|
||||
|
||||
|
||||
class CheckpointIOBackend(ABC):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.temps: List[str] = []
|
||||
|
||||
@abstractmethod
|
||||
def get_writer(self,
|
||||
base_name: str,
|
||||
overwrite: bool = False,
|
||||
rank: int = 0,
|
||||
world_size: int = 1) -> CheckpointWriter:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_reader(self, base_name: str) -> CheckpointReader:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_temp(self, base_name: str) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clean_temp(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
@register('disk')
|
||||
class CheckpointDiskIO(CheckpointIOBackend):
|
||||
|
||||
def get_writer(self,
|
||||
base_name: str,
|
||||
overwrite: bool = False,
|
||||
rank: int = 0,
|
||||
world_size: int = 1) -> CheckpointWriter:
|
||||
return DiskCheckpointWriter(base_name, overwrite, rank=rank, world_size=world_size)
|
||||
|
||||
def get_reader(self, base_name: str) -> CheckpointReader:
|
||||
return DiskCheckpointReader(base_name)
|
||||
|
||||
def get_temp(self, base_name: str) -> str:
|
||||
temp_dir_name = tempfile.mkdtemp(dir=base_name)
|
||||
self.temps.append(temp_dir_name)
|
||||
return temp_dir_name
|
||||
|
||||
def clean_temp(self) -> None:
|
||||
for temp_dir_name in self.temps:
|
||||
shutil.rmtree(temp_dir_name)
|
9
colossalai/utils/checkpoint_io/constant.py
Normal file
9
colossalai/utils/checkpoint_io/constant.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import re
|
||||
|
||||
GLOBAL_META_FILE_NAME = 'global_meta.bin'
|
||||
MODEL_CKPT_FILE_NAME = 'model.bin'
|
||||
OPTIM_CKPT_FILE_NAME = 'optim.bin'
|
||||
META_CKPT_FILE_NAME = 'meta.bin'
|
||||
OTHER_CKPT_FILE_NAME = 'other.bin'
|
||||
|
||||
CKPT_PAT = re.compile(r'global_meta|model|optim|meta|other')
|
227
colossalai/utils/checkpoint_io/convertor.py
Normal file
227
colossalai/utils/checkpoint_io/convertor.py
Normal file
@@ -0,0 +1,227 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from .distributed import merge_param, unmerge_param
|
||||
from .meta import ParamDistMeta, RedistMeta
|
||||
from .utils import (ModelCheckpointSharder, OptimizerCheckpointSharder, run_if_not_none)
|
||||
|
||||
|
||||
class CheckpointConvertor(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def complete(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ModelCheckpointConvertor(CheckpointConvertor):
|
||||
|
||||
def __init__(self, param_count: Dict[str, int]) -> None:
|
||||
super().__init__()
|
||||
self.param_count = param_count
|
||||
self.buffer: Dict[str, Dict[int, Tensor]] = defaultdict(dict)
|
||||
|
||||
@abstractmethod
|
||||
def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
|
||||
pass
|
||||
|
||||
def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
|
||||
for rank, state_dict in shard_dict.items():
|
||||
for k, tensor in state_dict.items():
|
||||
self.buffer[k][rank] = tensor
|
||||
converted_keys = set()
|
||||
for k, rank_dict in self.buffer.items():
|
||||
if len(rank_dict) == self.param_count[k]:
|
||||
tensors = []
|
||||
dist_metas = []
|
||||
for rank, tensor in rank_dict.items():
|
||||
tensors.append(tensor)
|
||||
if dist_meta_list[rank] is not None:
|
||||
dist_metas.append(dist_meta_list[rank][k])
|
||||
self.convert_tensors(k, tensors, dist_metas)
|
||||
converted_keys.add(k)
|
||||
for k in converted_keys:
|
||||
del self.buffer[k]
|
||||
|
||||
def complete(self) -> None:
|
||||
assert len(self.buffer) == 0
|
||||
|
||||
|
||||
class ModelCheckpointMerger(ModelCheckpointConvertor):
|
||||
|
||||
def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int]) -> None:
|
||||
super().__init__(param_count)
|
||||
self.sharder = ModelCheckpointSharder(max_shard_size)
|
||||
self.save_fn = save_fn
|
||||
|
||||
def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
|
||||
assert len(dist_metas) == len(tensors)
|
||||
tensor = merge_param(tensors, dist_metas)
|
||||
shard = self.sharder.append(key, tensor)
|
||||
run_if_not_none(self.save_fn, shard)
|
||||
|
||||
def complete(self) -> None:
|
||||
super().complete()
|
||||
run_if_not_none(self.save_fn, self.sharder.complete())
|
||||
|
||||
|
||||
class ModelCheckpointRedistor(ModelCheckpointConvertor):
|
||||
|
||||
def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
|
||||
redist_meta: RedistMeta) -> None:
|
||||
super().__init__(param_count)
|
||||
self.save_fns = save_fns
|
||||
self.redist_meta = redist_meta
|
||||
nprocs = len(save_fns)
|
||||
self.sharders = [ModelCheckpointSharder(max_shard_size) for _ in range(nprocs)]
|
||||
self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
for k, rank_meta in redist_meta.rank_meta.items():
|
||||
for rank, rank_info in rank_meta.items():
|
||||
self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)
|
||||
|
||||
def convert_tensors(self, key: str, tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> None:
|
||||
if len(dist_metas) == 0:
|
||||
# already global
|
||||
tensor = tensors[0]
|
||||
else:
|
||||
assert len(dist_metas) == len(tensors)
|
||||
tensor = merge_param(tensors, dist_metas)
|
||||
for tp_rank, tensor_list in enumerate(unmerge_param(tensor, self.redist_meta.param_meta[key])):
|
||||
for dp_rank, t in enumerate(tensor_list):
|
||||
for rank in self.rank_map[key][tp_rank][dp_rank]:
|
||||
shard = self.sharders[rank].append(key, t)
|
||||
run_if_not_none(self.save_fns[rank], shard)
|
||||
|
||||
def complete(self) -> None:
|
||||
super().complete()
|
||||
for rank, save_fn in enumerate(self.save_fns):
|
||||
run_if_not_none(save_fn, self.sharders[rank].complete())
|
||||
|
||||
|
||||
class OptimizerCheckpointConvertor(CheckpointConvertor):
|
||||
|
||||
def __init__(self, param_count: Dict[str, int], param_to_os: Optional[Dict[str, int]],
|
||||
paired_os: Optional[Dict[int, dict]]) -> None:
|
||||
super().__init__()
|
||||
self.param_count = param_count
|
||||
self.param_to_os = param_to_os
|
||||
self.paired_os = paired_os
|
||||
self.buffer: Dict[int, Dict[int, dict]] = defaultdict(dict)
|
||||
self.os_to_param = {v: k for k, v in param_to_os.items()}
|
||||
|
||||
@abstractmethod
|
||||
def setup(self, param_groups: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
|
||||
pass
|
||||
|
||||
def append(self, shard_dict: Dict[int, dict], dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
|
||||
for rank, state_dict in shard_dict.items():
|
||||
self.setup(state_dict['param_groups'])
|
||||
for idx, state in state_dict['state'].items():
|
||||
self.buffer[idx][rank] = state
|
||||
converted_indices = set()
|
||||
for idx, rank_dict in self.buffer.items():
|
||||
if len(rank_dict) == self.param_count[self.os_to_param[idx]]:
|
||||
states = []
|
||||
dist_metas = []
|
||||
for rank, state in rank_dict.items():
|
||||
states.append(state)
|
||||
if dist_meta_list[rank] is not None:
|
||||
dist_metas.append(dist_meta_list[rank][self.os_to_param[idx]])
|
||||
self.convert_states(idx, states, dist_metas)
|
||||
converted_indices.add(idx)
|
||||
for idx in converted_indices:
|
||||
del self.buffer[idx]
|
||||
|
||||
def complete(self) -> None:
|
||||
assert len(self.buffer) == 0
|
||||
|
||||
|
||||
class OptimizerCheckpointMerger(OptimizerCheckpointConvertor):
|
||||
|
||||
def __init__(self, max_shard_size: int, save_fn: Callable[[dict], Any], param_count: Dict[str, int],
|
||||
param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]]) -> None:
|
||||
super().__init__(param_count, param_to_os, paired_os)
|
||||
self.max_shard_size = max_shard_size
|
||||
self.save_fn = save_fn
|
||||
self.sharder = None
|
||||
|
||||
def setup(self, param_groups: dict) -> None:
|
||||
if self.sharder is None:
|
||||
self.sharder = OptimizerCheckpointSharder(self.max_shard_size, param_groups)
|
||||
|
||||
def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
|
||||
assert len(dist_metas) == len(states)
|
||||
new_state = {}
|
||||
for state_key, state_tensor in states[0].items():
|
||||
if self.paired_os[idx][state_key]:
|
||||
new_state[state_key] = merge_param([state[state_key] for state in states], dist_metas)
|
||||
else:
|
||||
new_state[state_key] = state_tensor
|
||||
shard = self.sharder.append(idx, new_state)
|
||||
run_if_not_none(self.save_fn, shard)
|
||||
|
||||
def complete(self) -> None:
|
||||
super().complete()
|
||||
run_if_not_none(self.save_fn, self.sharder.complete())
|
||||
|
||||
|
||||
class OptimizerCheckpointRedistor(OptimizerCheckpointConvertor):
|
||||
|
||||
def __init__(self, max_shard_size: int, save_fns: List[Callable[[dict], Any]], param_count: Dict[str, int],
|
||||
param_to_os: Optional[Dict[str, int]], paired_os: Optional[Dict[int, dict]],
|
||||
redist_meta: RedistMeta) -> None:
|
||||
super().__init__(param_count, param_to_os, paired_os)
|
||||
self.max_shard_size = max_shard_size
|
||||
self.save_fns = save_fns
|
||||
self.redist_meta = redist_meta
|
||||
self.sharders: List[OptimizerCheckpointSharder] = []
|
||||
self.rank_map = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
|
||||
for k, rank_meta in redist_meta.rank_meta.items():
|
||||
for rank, rank_info in rank_meta.items():
|
||||
self.rank_map[k][rank_info.tp_rank][rank_info.dp_rank].append(rank)
|
||||
|
||||
def setup(self, param_groups: dict) -> None:
|
||||
if len(self.sharders) == 0:
|
||||
nprocs = len(self.save_fns)
|
||||
for _ in range(nprocs):
|
||||
self.sharders.append(OptimizerCheckpointSharder(self.max_shard_size, param_groups))
|
||||
|
||||
def convert_states(self, idx: int, states: List[dict], dist_metas: List[ParamDistMeta]) -> None:
|
||||
need_merge: bool = True
|
||||
if len(dist_metas) == 0:
|
||||
need_merge = False
|
||||
else:
|
||||
assert len(dist_metas) == len(states)
|
||||
new_states = [{} for _ in range(len(self.save_fns))]
|
||||
for state_key, state_tensor in states[0].items():
|
||||
if self.paired_os[idx][state_key]:
|
||||
if need_merge:
|
||||
tensor = merge_param([state[state_key] for state in states], dist_metas)
|
||||
else:
|
||||
tensor = state_tensor
|
||||
for tp_rank, tensor_list in enumerate(
|
||||
unmerge_param(tensor, self.redist_meta.param_meta[self.os_to_param[idx]])):
|
||||
for dp_rank, t in enumerate(tensor_list):
|
||||
for rank in self.rank_map[self.os_to_param[idx]][tp_rank][dp_rank]:
|
||||
new_states[rank][state_key] = t
|
||||
else:
|
||||
for new_state in new_states:
|
||||
new_state[state_key] = state_tensor
|
||||
for rank, new_state in enumerate(new_states):
|
||||
shard = self.sharders[rank].append(idx, new_state)
|
||||
run_if_not_none(self.save_fns[rank], shard)
|
||||
|
||||
def complete(self) -> None:
|
||||
super().complete()
|
||||
for rank, save_fn in enumerate(self.save_fns):
|
||||
run_if_not_none(save_fn, self.sharders[rank].complete())
|
127
colossalai/utils/checkpoint_io/distributed.py
Normal file
127
colossalai/utils/checkpoint_io/distributed.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import torch
|
||||
from numpy import prod
|
||||
from torch import Tensor
|
||||
from typing import List, Optional, Tuple
|
||||
from collections import defaultdict
|
||||
from .meta import ParamDistMeta, ParamRedistMeta
|
||||
|
||||
|
||||
def unflatten_zero_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
|
||||
assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
|
||||
for dist_meta in dist_metas[1:]:
|
||||
assert dist_meta.zero_meta == dist_metas[0].zero_meta, 'Expect all params have the same zero meta.'
|
||||
if not dist_metas[0].used_zero:
|
||||
# tensors are replicate
|
||||
return tensors[0]
|
||||
numel = dist_metas[0].zero_numel
|
||||
orig_shape = dist_metas[0].zero_orig_shape
|
||||
tensors = [t[1] for t in sorted(zip(dist_metas, tensors), key=lambda tp: tp[0].dp_rank)]
|
||||
assert numel == sum(t.numel() for t in tensors), 'Expect numel of all params is equal to zero_numel.'
|
||||
return torch.cat(tensors).reshape(orig_shape)
|
||||
|
||||
|
||||
def gather_tp_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
|
||||
assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
|
||||
for dist_meta in dist_metas[1:]:
|
||||
assert dist_meta.tp_meta == dist_metas[0].tp_meta, 'Expect all params have the same tp meta.'
|
||||
for t in tensors[1:]:
|
||||
assert t.shape == tensors[0].shape, 'Expect all params have the same shape.'
|
||||
if not dist_metas[0].used_tp:
|
||||
# tensors are replicate
|
||||
return tensors[0]
|
||||
total_parts = prod(dist_meta.tp_num_parts)
|
||||
assert dist_meta.tp_world_size == total_parts, \
|
||||
f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {dist_meta.tp_world_size}.'
|
||||
shard_info = sorted(zip(dist_meta.tp_shard_dims, dist_meta.tp_num_parts), key=lambda t: t[0], reverse=True)
|
||||
for dim, num_parts in shard_info:
|
||||
buffer = []
|
||||
for start in range(0, len(tensors), num_parts):
|
||||
buffer.append(torch.cat(tensors[start:start + num_parts], dim))
|
||||
tensors = buffer
|
||||
assert len(tensors) == 1
|
||||
return tensors[0]
|
||||
|
||||
|
||||
def validate_parallel_info(dist_metas: List[ParamDistMeta]) -> None:
|
||||
assert len(dist_metas) > 0
|
||||
# check world size
|
||||
for dist_meta in dist_metas[1:]:
|
||||
assert dist_meta.dp_world_size == dist_metas[
|
||||
0].dp_world_size, 'Expect all dist meta have the same dp_world_size'
|
||||
assert dist_meta.tp_world_size == dist_metas[
|
||||
0].tp_world_size, 'Expect all dist meta have the same tp_world_size'
|
||||
|
||||
|
||||
def deduplicate_params(tensors: List[Tensor],
|
||||
dist_metas: List[ParamDistMeta]) -> Tuple[List[Tensor], List[ParamDistMeta]]:
|
||||
unique_dist_meta = []
|
||||
unique_idx = []
|
||||
for i, dist_meta in enumerate(dist_metas):
|
||||
if dist_meta not in unique_dist_meta:
|
||||
unique_dist_meta.append(dist_meta)
|
||||
unique_idx.append(i)
|
||||
return [tensors[i] for i in unique_idx], [dist_metas[i] for i in unique_idx]
|
||||
|
||||
|
||||
def merge_param(tensors: List[Tensor], dist_metas: List[ParamDistMeta]) -> Tensor:
|
||||
assert len(tensors) > 0 and len(dist_metas) > 0 and len(tensors) == len(dist_metas)
|
||||
# validate parallel info
|
||||
validate_parallel_info(dist_metas)
|
||||
tensors, dist_metas = deduplicate_params(tensors, dist_metas)
|
||||
unflattened_tensors = []
|
||||
# group zero params by tp rank
|
||||
tensor_dict = defaultdict(list)
|
||||
dist_meta_dict = defaultdict(list)
|
||||
for t, dist_meta in zip(tensors, dist_metas):
|
||||
tensor_dict[dist_meta.tp_rank].append(t)
|
||||
dist_meta_dict[dist_meta.tp_rank].append(dist_meta)
|
||||
assert len(tensor_dict
|
||||
) == dist_metas[0].tp_world_size, f'Expect {dist_metas[0].tp_world_size} ranks, got {len(tensor_dict)}'
|
||||
for tp_rank in tensor_dict.keys():
|
||||
unflattened_tensors.append(unflatten_zero_param(tensor_dict[tp_rank], dist_meta_dict[tp_rank]))
|
||||
return gather_tp_param(unflattened_tensors, [dist_meta_list[0] for dist_meta_list in dist_meta_dict.values()])
|
||||
|
||||
|
||||
def split_tp_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
|
||||
if not redist_meta.used_tp:
|
||||
assert redist_meta.tp_world_size == 1, 'Expect tp_world_size == 1, when no tp meta provided.'
|
||||
return [tensor]
|
||||
total_parts = prod(redist_meta.tp_num_parts)
|
||||
assert redist_meta.tp_world_size == total_parts, f'Expect prod(tp_num_parts) == tp_world_size, got {total_parts} and {redist_meta.tp_world_size}.'
|
||||
shard_info = sorted(zip(redist_meta.tp_shard_dims, redist_meta.tp_num_parts), key=lambda t: t[0])
|
||||
tensors = [tensor]
|
||||
for dim, num_parts in shard_info:
|
||||
buffer = []
|
||||
for t in tensors:
|
||||
assert t.size(dim) % num_parts == 0, \
|
||||
f'Expect dim{dim} of tensor({tensor.shape}) is divisible by {num_parts}.'
|
||||
chunks = [chunk.contiguous() for chunk in t.chunk(num_parts, dim)]
|
||||
buffer.extend(chunks)
|
||||
tensors = buffer
|
||||
assert len(tensors) == redist_meta.tp_world_size
|
||||
return tensors
|
||||
|
||||
|
||||
def flatten_zero_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[Tensor]:
|
||||
if not redist_meta.used_zero:
|
||||
return [tensor] * redist_meta.dp_world_size
|
||||
tensors: List[Optional[Tensor]] = [
|
||||
torch.empty(0, dtype=tensor.dtype, device=tensor.device) for _ in range(redist_meta.zero_start_dp_rank)
|
||||
]
|
||||
offsets = redist_meta.zero_offsets + [tensor.numel()]
|
||||
for i, offset in enumerate(offsets[:-1]):
|
||||
end = offsets[i + 1]
|
||||
tensors.append(tensor.view(-1)[offset:end])
|
||||
if len(tensors) < redist_meta.dp_world_size:
|
||||
tensors.extend([
|
||||
torch.empty(0, dtype=tensor.dtype, device=tensor.device)
|
||||
for _ in range(redist_meta.dp_world_size - len(tensors))
|
||||
])
|
||||
assert len(tensors) == redist_meta.dp_world_size
|
||||
return tensors
|
||||
|
||||
|
||||
def unmerge_param(tensor: Tensor, redist_meta: ParamRedistMeta) -> List[List[Tensor]]:
|
||||
tensors = split_tp_param(tensor, redist_meta)
|
||||
tensors = [flatten_zero_param(t, redist_meta) for t in tensors]
|
||||
return tensors
|
170
colossalai/utils/checkpoint_io/io.py
Normal file
170
colossalai/utils/checkpoint_io/io.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import torch.distributed as dist
|
||||
from torch.nn import Module
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .backend import get_backend
|
||||
from .convertor import (CheckpointConvertor, ModelCheckpointMerger, ModelCheckpointRedistor, OptimizerCheckpointMerger,
|
||||
OptimizerCheckpointRedistor)
|
||||
from .meta import ParamDistMeta, RedistMeta
|
||||
from .utils import build_checkpoints, optimizer_load_state_dict
|
||||
|
||||
|
||||
def save(path: str,
|
||||
model: Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
param_to_os: Optional[Dict[str, int]] = None,
|
||||
dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
|
||||
max_shard_size_gb: float = 0.0,
|
||||
overwrite: bool = False,
|
||||
backend: str = 'disk',
|
||||
**kwargs: Any) -> None:
|
||||
io_backend = get_backend(backend)
|
||||
if dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
rank = 0
|
||||
world_size = 1
|
||||
if world_size == 1:
|
||||
# global doesn't need dist_meta
|
||||
dist_meta = None
|
||||
else:
|
||||
assert dist_meta is not None
|
||||
max_shard_size = int(max_shard_size_gb * 1024**3)
|
||||
model_checkpoints, optimizer_checkpoints, meta_checkpoint = build_checkpoints(max_shard_size, model, optimizer,
|
||||
param_to_os, dist_meta)
|
||||
writer = io_backend.get_writer(path, overwrite, rank, world_size)
|
||||
writer.save_others(kwargs)
|
||||
for model_checkpoint in model_checkpoints:
|
||||
writer.save_model(model_checkpoint)
|
||||
for optimizer_checkpoint in optimizer_checkpoints:
|
||||
writer.save_optimizer(optimizer_checkpoint)
|
||||
writer.save_meta(meta_checkpoint)
|
||||
|
||||
|
||||
def merge(path: str,
|
||||
output_path: str,
|
||||
max_shard_size_gb: float = 0.0,
|
||||
overwrite: bool = False,
|
||||
backend: str = 'disk') -> bool:
|
||||
io_backend = get_backend(backend)
|
||||
if dist.is_initialized() and dist.get_rank() != 0:
|
||||
return False
|
||||
reader = io_backend.get_reader(path)
|
||||
if len(reader.meta_list) == 1:
|
||||
# already global
|
||||
warnings.warn(f'Checkpoint at "{path}" is already global, nothing to do.')
|
||||
return False
|
||||
dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
|
||||
writer = io_backend.get_writer(output_path, overwrite=overwrite)
|
||||
writer.save_others(reader.load_others())
|
||||
max_shard_size = int(max_shard_size_gb * 1024**3)
|
||||
_convert_shards(ModelCheckpointMerger(max_shard_size, writer.save_model, param_count), reader.load_models(),
|
||||
dist_meta_list)
|
||||
_convert_shards(
|
||||
OptimizerCheckpointMerger(max_shard_size, writer.save_optimizer, param_count, param_to_os, paired_os),
|
||||
reader.load_optimizers(), dist_meta_list)
|
||||
meta_checkpoint = {'dist_meta': None, 'params': list(param_count.keys())}
|
||||
if param_to_os is not None:
|
||||
meta_checkpoint['param_to_os'] = param_to_os
|
||||
meta_checkpoint['paired_os'] = paired_os
|
||||
writer.save_meta(meta_checkpoint)
|
||||
return True
|
||||
|
||||
|
||||
def redist(path: str,
|
||||
output_path: str,
|
||||
redist_meta: RedistMeta,
|
||||
dist_metas: List[Dict[str, ParamDistMeta]],
|
||||
max_shard_size_gb: float = 0.0,
|
||||
overwrite: bool = False,
|
||||
backend: str = 'disk') -> bool:
|
||||
io_backend = get_backend(backend)
|
||||
if dist.is_initialized() and dist.get_rank() != 0:
|
||||
return False
|
||||
nprocs = len(dist_metas)
|
||||
reader = io_backend.get_reader(path)
|
||||
dist_meta_list, param_count, param_to_os, paired_os = reader.load_meta()
|
||||
do_redist: bool = False
|
||||
if len(dist_meta_list) == nprocs:
|
||||
for a, b in zip(dist_metas, dist_meta_list):
|
||||
if a != b:
|
||||
do_redist = True
|
||||
break
|
||||
else:
|
||||
do_redist = True
|
||||
if not do_redist:
|
||||
warnings.warn(f'Checkpoint at "{path}" is not required to redist, nothing to do.')
|
||||
return False
|
||||
|
||||
writers = [io_backend.get_writer(output_path, overwrite, rank, nprocs) for rank in range(nprocs)]
|
||||
writers[0].save_others(reader.load_others())
|
||||
max_shard_size = int(max_shard_size_gb * 1024**3)
|
||||
_convert_shards(
|
||||
ModelCheckpointRedistor(max_shard_size, [writer.save_model for writer in writers], param_count, redist_meta),
|
||||
reader.load_models(), dist_meta_list)
|
||||
_convert_shards(
|
||||
OptimizerCheckpointRedistor(max_shard_size, [writer.save_optimizer for writer in writers], param_count,
|
||||
param_to_os, paired_os, redist_meta), reader.load_optimizers(), dist_meta_list)
|
||||
for writer, dist_meta in zip(writers, dist_metas):
|
||||
meta_checkpoint = {'dist_meta': dist_meta, 'params': list(param_count.keys())}
|
||||
if param_to_os is not None:
|
||||
meta_checkpoint['param_to_os'] = param_to_os
|
||||
meta_checkpoint['paired_os'] = paired_os
|
||||
writer.save_meta(meta_checkpoint)
|
||||
return True
|
||||
|
||||
|
||||
def _convert_shards(convertor: CheckpointConvertor, shard_generator: Generator[dict, None, None],
|
||||
dist_meta_list: List[Optional[Dict[str, ParamDistMeta]]]) -> None:
|
||||
for shard_dict in shard_generator:
|
||||
convertor.append(shard_dict, dist_meta_list)
|
||||
convertor.complete()
|
||||
|
||||
|
||||
def load(path: str,
|
||||
model: Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
redist_meta: Optional[RedistMeta] = None,
|
||||
dist_metas: Optional[List[Dict[str, ParamDistMeta]]] = None,
|
||||
max_shard_size_gb: float = 0.0,
|
||||
backend: str = 'disk') -> dict:
|
||||
is_global: bool = not dist.is_initialized() or dist.get_world_size() == 1
|
||||
rank: int = dist.get_rank() if dist.is_initialized() else 0
|
||||
is_main_process: bool = rank == 0
|
||||
# validate args
|
||||
if redist_meta is None or dist_metas is None:
|
||||
assert is_global
|
||||
io_backend = get_backend(backend)
|
||||
read_path: str = path
|
||||
if is_main_process:
|
||||
# pre-process checkpoints
|
||||
temp_path = io_backend.get_temp(path)
|
||||
if is_global:
|
||||
wrote = merge(path, temp_path, max_shard_size_gb, backend=backend)
|
||||
else:
|
||||
wrote = redist(path, temp_path, redist_meta, dist_metas, max_shard_size_gb, backend=backend)
|
||||
if wrote:
|
||||
read_path = temp_path
|
||||
if not is_global:
|
||||
bcast_list = [read_path] if is_main_process else [None]
|
||||
dist.broadcast_object_list(bcast_list)
|
||||
read_path = bcast_list[0]
|
||||
reader = io_backend.get_reader(read_path)
|
||||
# load model
|
||||
for shard in reader.load_model(rank):
|
||||
model.load_state_dict(shard, strict=False)
|
||||
if optimizer is not None:
|
||||
for shard in reader.load_optimizer(rank):
|
||||
# optimizer.load_state_dict(shard)
|
||||
optimizer_load_state_dict(optimizer, shard)
|
||||
others_dict = reader.load_others()
|
||||
if not is_global:
|
||||
dist.barrier()
|
||||
# clean up temp
|
||||
if is_main_process:
|
||||
io_backend.clean_temp()
|
||||
return others_dict
|
81
colossalai/utils/checkpoint_io/meta.py
Normal file
81
colossalai/utils/checkpoint_io/meta.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Set, Dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamDistMeta:
|
||||
# parallel info
|
||||
dp_rank: int
|
||||
dp_world_size: int
|
||||
tp_rank: int
|
||||
tp_world_size: int
|
||||
# tp info
|
||||
tp_shard_dims: Optional[List[int]] = None
|
||||
tp_num_parts: Optional[List[int]] = None
|
||||
# zero info
|
||||
zero_numel: Optional[int] = None
|
||||
zero_orig_shape: Optional[List[int]] = None
|
||||
|
||||
@property
|
||||
def used_tp(self) -> bool:
|
||||
return self.tp_shard_dims is not None and self.tp_num_parts is not None
|
||||
|
||||
@property
|
||||
def used_zero(self) -> bool:
|
||||
return self.zero_numel is not None and self.zero_orig_shape is not None
|
||||
|
||||
@property
|
||||
def parallel_meta(self) -> tuple:
|
||||
return self.dp_rank, self.dp_world_size, self.tp_rank, self.tp_world_size
|
||||
|
||||
@property
|
||||
def tp_meta(self) -> tuple:
|
||||
return self.tp_shard_dims, self.tp_num_parts
|
||||
|
||||
@property
|
||||
def zero_meta(self) -> tuple:
|
||||
return self.zero_numel, self.zero_orig_shape
|
||||
|
||||
@staticmethod
|
||||
def from_dict(d: dict) -> 'ParamDistMeta':
|
||||
return ParamDistMeta(**d)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamRedistMeta:
|
||||
# parallel info
|
||||
dp_world_size: int
|
||||
tp_world_size: int
|
||||
# tp info
|
||||
tp_shard_dims: Optional[List[int]] = None
|
||||
tp_num_parts: Optional[List[int]] = None
|
||||
# zero info
|
||||
zero_start_dp_rank: Optional[int] = None
|
||||
zero_offsets: Optional[List[int]] = None
|
||||
|
||||
@property
|
||||
def used_tp(self) -> bool:
|
||||
return self.tp_shard_dims is not None and self.tp_num_parts is not None
|
||||
|
||||
@property
|
||||
def used_zero(self) -> bool:
|
||||
return self.zero_start_dp_rank is not None and self.zero_offsets is not None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankRedistMeta:
|
||||
dp_rank: int
|
||||
tp_rank: int
|
||||
pp_rank: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineRedistMeta:
|
||||
params: Set[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class RedistMeta:
|
||||
rank_meta: Dict[str, Dict[int, RankRedistMeta]]
|
||||
pipeline_meta: List[PipelineRedistMeta]
|
||||
param_meta: Dict[str, ParamRedistMeta]
|
131
colossalai/utils/checkpoint_io/reader.py
Normal file
131
colossalai/utils/checkpoint_io/reader.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from typing import Dict, Generator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .constant import GLOBAL_META_FILE_NAME, OTHER_CKPT_FILE_NAME
|
||||
from .meta import ParamDistMeta
|
||||
from .utils import is_duplicated_list
|
||||
|
||||
|
||||
class CheckpointReader(ABC):
|
||||
|
||||
def __init__(self, base_name: str) -> None:
|
||||
super().__init__()
|
||||
self.base_name = base_name
|
||||
self.meta_list = []
|
||||
|
||||
@abstractmethod
|
||||
def read(self, name: str) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_meta(
|
||||
self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, rank: int) -> Generator[dict, None, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_models(self) -> Generator[Dict[int, dict], None, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_others(self) -> dict:
|
||||
pass
|
||||
|
||||
|
||||
class DiskCheckpointReader(CheckpointReader):
|
||||
|
||||
def __init__(self, base_name: str) -> None:
|
||||
super().__init__(base_name)
|
||||
assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
|
||||
global_meta = self.read(GLOBAL_META_FILE_NAME)
|
||||
for meta_file_name in global_meta['meta']:
|
||||
meta = self.read(meta_file_name)
|
||||
if meta.get('dist_meta', None) is None:
|
||||
# only global checkpoint can have empty dist_meta
|
||||
assert len(global_meta['meta']) == 1
|
||||
self.meta_list.append(meta)
|
||||
|
||||
def read(self, name: str) -> dict:
|
||||
return torch.load(os.path.join(self.base_name, name))
|
||||
|
||||
def load_meta(
|
||||
self) -> Tuple[List[Optional[Dict[str, ParamDistMeta]]], Dict[str, int], Optional[dict], Optional[dict]]:
|
||||
meta_infos = [(meta.get('dist_meta', None), meta['params'], meta.get('param_to_os',
|
||||
None), meta.get('paired_os', None))
|
||||
for meta in self.meta_list]
|
||||
dist_meta_list, params_list, param_to_os_list, paired_os_list = zip(*meta_infos)
|
||||
# reduce param_count
|
||||
param_count = Counter(p for params in params_list for p in params)
|
||||
# validate param_to_os
|
||||
assert is_duplicated_list(param_to_os_list)
|
||||
assert is_duplicated_list(paired_os_list)
|
||||
return list(dist_meta_list), param_count, param_to_os_list[0], paired_os_list[0]
|
||||
|
||||
def _load_shard(self, shard_type: str, rank: int) -> Generator[dict, None, None]:
|
||||
meta = self.meta_list[rank]
|
||||
checkpoint_names = meta.get(shard_type, [])
|
||||
for name in checkpoint_names:
|
||||
yield self.read(name)
|
||||
|
||||
def load_model(self, rank: int) -> Generator[dict, None, None]:
|
||||
return self._load_shard('model', rank)
|
||||
|
||||
def load_models(self) -> Generator[Dict[int, dict], None, None]:
|
||||
indices = [0] * len(self.meta_list)
|
||||
while True:
|
||||
shards = {}
|
||||
for i, meta in enumerate(self.meta_list):
|
||||
model_checkpoint_names = meta.get('model', [])
|
||||
if indices[i] < len(model_checkpoint_names):
|
||||
shards[i] = self.read(model_checkpoint_names[indices[i]])
|
||||
indices[i] += 1
|
||||
if len(shards) > 0:
|
||||
yield shards
|
||||
else:
|
||||
break
|
||||
|
||||
def load_optimizer(self, rank: int) -> Generator[dict, None, None]:
|
||||
param_groups = None
|
||||
for shard in self._load_shard('optimizer', rank):
|
||||
if param_groups is None:
|
||||
param_groups = shard['param_groups']
|
||||
else:
|
||||
shard['param_groups'] = param_groups
|
||||
yield shard
|
||||
|
||||
def load_optimizers(self) -> Generator[Dict[int, dict], None, None]:
|
||||
indices = [0] * len(self.meta_list)
|
||||
param_groups = []
|
||||
while True:
|
||||
shards = {}
|
||||
for i, meta in enumerate(self.meta_list):
|
||||
optimizer_checkpoint_names = meta.get('optimizer', [])
|
||||
if indices[i] < len(optimizer_checkpoint_names):
|
||||
shards[i] = self.read(optimizer_checkpoint_names[indices[i]])
|
||||
if indices[i] == 0:
|
||||
param_groups.append(shards[i]['param_groups'])
|
||||
else:
|
||||
shards[i]['param_groups'] = param_groups[i]
|
||||
indices[i] += 1
|
||||
if len(shards) > 0:
|
||||
yield shards
|
||||
else:
|
||||
break
|
||||
|
||||
def load_others(self) -> dict:
|
||||
return self.read(OTHER_CKPT_FILE_NAME)
|
223
colossalai/utils/checkpoint_io/utils.py
Normal file
223
colossalai/utils/checkpoint_io/utils.py
Normal file
@@ -0,0 +1,223 @@
|
||||
import warnings
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .meta import ParamDistMeta
|
||||
|
||||
|
||||
def run_if_not_none(fn: Callable[[Any], Any], arg: Any) -> Any:
|
||||
if arg is not None:
|
||||
return fn(arg)
|
||||
|
||||
|
||||
def get_param_to_os(model: Module, optimizer: Optimizer) -> Dict[str, int]:
|
||||
# ensure all params in optimizer are in model state dict
|
||||
params_set = set(id(p) for p in model.parameters())
|
||||
for group in optimizer.param_groups:
|
||||
for p in group['params']:
|
||||
assert id(p) in params_set
|
||||
param_mappings = {}
|
||||
start_index = 0
|
||||
|
||||
def get_group_mapping(group):
|
||||
nonlocal start_index
|
||||
param_mappings.update(
|
||||
{id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
|
||||
start_index += len(group['params'])
|
||||
|
||||
for g in optimizer.param_groups:
|
||||
get_group_mapping(g)
|
||||
return {k: param_mappings[id(p)] for k, p in model.named_parameters()}
|
||||
|
||||
|
||||
def compute_optimizer_state_size(state: Dict[str, Any]) -> int:
|
||||
size = 0
|
||||
for v in state.values():
|
||||
if isinstance(v, Tensor):
|
||||
size += v.numel() * v.element_size()
|
||||
return size
|
||||
|
||||
|
||||
class ModelCheckpointSharder:
|
||||
|
||||
def __init__(self, max_shard_size: int) -> None:
|
||||
self.max_shard_size = max_shard_size
|
||||
self.buffer: Dict[str, Tensor] = {}
|
||||
self.buffer_size: int = 0
|
||||
|
||||
def append(self, key: str, tensor: Tensor) -> Optional[dict]:
|
||||
retval = None
|
||||
if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
|
||||
retval = self.buffer
|
||||
self.buffer = {}
|
||||
self.buffer_size = 0
|
||||
self.buffer[key] = tensor
|
||||
self.buffer_size += tensor.numel() * tensor.element_size()
|
||||
return retval
|
||||
|
||||
def extend(self, state_dict: Dict[str, Tensor]) -> List[dict]:
|
||||
shards = []
|
||||
for key, tensor in state_dict.items():
|
||||
shard = self.append(key, tensor)
|
||||
run_if_not_none(shards.append, shard)
|
||||
return shards
|
||||
|
||||
def complete(self) -> Optional[dict]:
|
||||
return self.buffer if len(self.buffer) > 0 else None
|
||||
|
||||
|
||||
class OptimizerCheckpointSharder:
|
||||
|
||||
def __init__(self, max_shard_size: int, param_groups: dict) -> None:
|
||||
self.max_shard_size = max_shard_size
|
||||
self.buffer: Dict[str, dict] = {'state': {}, 'param_groups': param_groups}
|
||||
self.buffer_size: int = 0
|
||||
self.returned_first: bool = False
|
||||
|
||||
def append(self, key: int, state: dict) -> Optional[dict]:
|
||||
retval = None
|
||||
if self.max_shard_size > 0 and self.buffer_size >= self.max_shard_size:
|
||||
retval = self.buffer
|
||||
self.buffer = {'state': {}}
|
||||
self.buffer_size = 0
|
||||
self.buffer['state'][key] = state
|
||||
self.buffer_size += compute_optimizer_state_size(state)
|
||||
return retval
|
||||
|
||||
def extend(self, state_dict: Dict[str, dict]) -> List[dict]:
|
||||
shards = []
|
||||
for key, state in state_dict['state'].items():
|
||||
shard = self.append(key, state)
|
||||
run_if_not_none(shards.append, shard)
|
||||
return shards
|
||||
|
||||
def complete(self) -> Optional[dict]:
|
||||
return self.buffer if len(self.buffer['state']) > 0 else None
|
||||
|
||||
|
||||
def shard_checkpoint(max_shard_size: int,
|
||||
model_state_dict: Dict[str, Tensor],
|
||||
optimizer_state_dict: Optional[dict] = None,
|
||||
param_to_os: Optional[dict] = None) -> Tuple[List[dict], List[dict]]:
|
||||
has_optimizer: bool = False
|
||||
if optimizer_state_dict is not None:
|
||||
assert param_to_os is not None
|
||||
os_to_param = {v: k for k, v in param_to_os.items()}
|
||||
for os_key in optimizer_state_dict['state'].keys():
|
||||
assert os_key in os_to_param
|
||||
assert os_to_param[os_key] in model_state_dict
|
||||
has_optimizer = True
|
||||
model_sharder = ModelCheckpointSharder(max_shard_size)
|
||||
model_shards = model_sharder.extend(model_state_dict)
|
||||
run_if_not_none(model_shards.append, model_sharder.complete())
|
||||
if not has_optimizer:
|
||||
return model_shards, []
|
||||
optimizer_sharder = OptimizerCheckpointSharder(max_shard_size, optimizer_state_dict['param_groups'])
|
||||
optimizer_shards = optimizer_sharder.extend(optimizer_state_dict)
|
||||
run_if_not_none(optimizer_shards.append, optimizer_sharder.complete())
|
||||
return model_shards, optimizer_shards
|
||||
|
||||
|
||||
def get_paired_os(model_state_dict: Dict[str, Tensor], optimizer_state_dict: dict, param_to_os: Dict[str, int]) -> dict:
|
||||
os_to_param = {v: k for k, v in param_to_os.items()}
|
||||
paired_os = {}
|
||||
for idx, state in optimizer_state_dict['state'].items():
|
||||
paired_os[idx] = {}
|
||||
p = model_state_dict[os_to_param[idx]]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, Tensor) and v.shape == p.shape:
|
||||
paired_os[idx][k] = True
|
||||
else:
|
||||
paired_os[idx][k] = False
|
||||
return paired_os
|
||||
|
||||
|
||||
def build_checkpoints(max_size: int,
|
||||
model: Module,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
param_to_os: Optional[Dict[str, int]] = None,
|
||||
dist_meta: Optional[Dict[str, ParamDistMeta]] = None,
|
||||
eliminate_replica: bool = False) -> Tuple[List[dict], List[dict], dict]:
|
||||
save_global = dist_meta is None
|
||||
model_state_dict = model.state_dict()
|
||||
optimizer_state_dict = optimizer.state_dict() if optimizer else None
|
||||
meta = {'dist_meta': dist_meta}
|
||||
if optimizer:
|
||||
param_to_os = param_to_os or get_param_to_os(model, optimizer)
|
||||
paired_os = get_paired_os(model_state_dict, optimizer_state_dict, param_to_os)
|
||||
meta['param_to_os'] = param_to_os
|
||||
meta['paired_os'] = paired_os
|
||||
if not save_global and eliminate_replica:
|
||||
# filter dp replicated params
|
||||
model_state_dict = {
|
||||
k: v for k, v in model_state_dict.items() if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
|
||||
}
|
||||
if optimizer:
|
||||
optimizer_state_dict['state'] = {
|
||||
param_to_os[k]: optimizer_state_dict['state'][param_to_os[k]]
|
||||
for k in model_state_dict.keys()
|
||||
if dist_meta[k].used_zero or dist_meta[k].dp_rank == 0
|
||||
}
|
||||
meta['params'] = list(model_state_dict.keys())
|
||||
if len(model_state_dict) == 0:
|
||||
warnings.warn('model state dict is empty, checkpoint is not saved')
|
||||
return [], [], meta
|
||||
model_checkpoints, optimizer_checkpoints = shard_checkpoint(max_size, model_state_dict, optimizer_state_dict,
|
||||
param_to_os)
|
||||
return model_checkpoints, optimizer_checkpoints, meta
|
||||
|
||||
|
||||
def is_duplicated_list(list_: List[Any]) -> bool:
|
||||
if len(list_) == 0:
|
||||
return True
|
||||
elem = list_[0]
|
||||
for x in list_[1:]:
|
||||
if x != elem:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def copy_optimizer_state(src_state: dict, dest_state: dict) -> None:
|
||||
for k, v in src_state.items():
|
||||
if k in dest_state:
|
||||
old_v = dest_state[k]
|
||||
if isinstance(old_v, Tensor):
|
||||
old_v.copy_(v)
|
||||
else:
|
||||
dest_state[k] = v
|
||||
|
||||
|
||||
def optimizer_load_state_dict(optimizer: Optimizer, state_dict: dict, strict: bool = False) -> None:
|
||||
assert optimizer.state_dict()['param_groups'] == state_dict['param_groups']
|
||||
state_dict = deepcopy(state_dict)
|
||||
groups = optimizer.param_groups
|
||||
saved_groups = state_dict['param_groups']
|
||||
idx_to_p: Dict[str, Parameter] = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in groups)))
|
||||
}
|
||||
missing_keys = list(set(idx_to_p.keys()) - set(state_dict['state'].keys()))
|
||||
unexpected_keys = []
|
||||
error_msgs = []
|
||||
for idx, state in state_dict['state'].items():
|
||||
if idx in idx_to_p:
|
||||
old_state = optimizer.state[idx_to_p[idx]]
|
||||
copy_optimizer_state(state, old_state)
|
||||
else:
|
||||
unexpected_keys.append(idx)
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Unexpected key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in unexpected_keys)))
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0, 'Missing key(s) in state_dict: {}. '.format(', '.join('"{}"'.format(k) for k in missing_keys)))
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(optimizer.__class__.__name__,
|
||||
"\n\t".join(error_msgs)))
|
98
colossalai/utils/checkpoint_io/writer.py
Normal file
98
colossalai/utils/checkpoint_io/writer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from .constant import MODEL_CKPT_FILE_NAME, OPTIM_CKPT_FILE_NAME, META_CKPT_FILE_NAME, OTHER_CKPT_FILE_NAME, GLOBAL_META_FILE_NAME
|
||||
import torch
|
||||
import os
|
||||
|
||||
|
||||
class CheckpointWriter(ABC):
|
||||
|
||||
def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
|
||||
super().__init__()
|
||||
self.base_name = base_name
|
||||
self.overwrite = overwrite
|
||||
self.rank = rank
|
||||
self.world_size = world_size
|
||||
self.is_distributed = world_size > 1
|
||||
self.is_main_process = rank == 0
|
||||
|
||||
@abstractmethod
|
||||
def write(self, name: str, state_dict: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_model(self, model_checkpoint: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_optimizer(self, optimizer_checkpoint: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_meta(self, meta_checkpoint: dict) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_others(self, kwargs: dict) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class DiskCheckpointWriter(CheckpointWriter):
|
||||
|
||||
def __init__(self, base_name: str, overwrite: bool = False, rank: int = 0, world_size: int = 1) -> None:
|
||||
super().__init__(base_name, overwrite, rank, world_size)
|
||||
if not os.path.exists(base_name):
|
||||
os.makedirs(base_name)
|
||||
assert os.path.isdir(base_name), f'"{base_name}" is not a directory'
|
||||
self.model_checkpoint_names = []
|
||||
self.optimizer_checkpoint_names = []
|
||||
self.is_meta_saved: bool = False
|
||||
self._save_global_meta()
|
||||
|
||||
def write(self, name: str, state_dict: dict) -> None:
|
||||
path = os.path.join(self.base_name, name)
|
||||
if os.path.exists(path) and not self.overwrite:
|
||||
raise RuntimeError(f'Save error: Checkpoint "{path}" exists. (overwrite = False)')
|
||||
torch.save(state_dict, path)
|
||||
|
||||
def _save_global_meta(self) -> None:
|
||||
if self.is_main_process:
|
||||
global_meta = {'meta': []}
|
||||
if self.is_distributed:
|
||||
for i in range(self.world_size):
|
||||
global_meta['meta'].append(META_CKPT_FILE_NAME.replace('.bin', f'-rank{i}.bin'))
|
||||
else:
|
||||
global_meta['meta'].append(META_CKPT_FILE_NAME)
|
||||
self.write(GLOBAL_META_FILE_NAME, global_meta)
|
||||
|
||||
def _get_checkpoint_name(self, base_name: str, shard_idx: Optional[int] = None) -> str:
|
||||
checkpoint_name = base_name
|
||||
if self.is_distributed:
|
||||
checkpoint_name = checkpoint_name.replace('.bin', f'-rank{self.rank}.bin')
|
||||
if shard_idx is not None:
|
||||
checkpoint_name = checkpoint_name.replace('.bin', f'-shard{shard_idx}.bin')
|
||||
return checkpoint_name
|
||||
|
||||
def save_model(self, model_checkpoint: dict) -> None:
|
||||
assert not self.is_meta_saved, 'Cannot save model after saving meta'
|
||||
name = self._get_checkpoint_name(MODEL_CKPT_FILE_NAME, len(self.model_checkpoint_names))
|
||||
self.write(name, model_checkpoint)
|
||||
self.model_checkpoint_names.append(name)
|
||||
|
||||
def save_optimizer(self, optimizer_checkpoint: dict) -> None:
|
||||
assert not self.is_meta_saved, 'Cannot save optimizer after saving meta'
|
||||
name = self._get_checkpoint_name(OPTIM_CKPT_FILE_NAME, len(self.optimizer_checkpoint_names))
|
||||
self.write(name, optimizer_checkpoint)
|
||||
self.optimizer_checkpoint_names.append(name)
|
||||
|
||||
def save_meta(self, meta_checkpoint: dict) -> None:
|
||||
if len(self.model_checkpoint_names) > 0:
|
||||
meta_checkpoint['model'] = self.model_checkpoint_names
|
||||
if len(self.optimizer_checkpoint_names) > 0:
|
||||
meta_checkpoint['optimizer'] = self.optimizer_checkpoint_names
|
||||
self.write(self._get_checkpoint_name(META_CKPT_FILE_NAME), meta_checkpoint)
|
||||
self.is_meta_saved = True
|
||||
|
||||
def save_others(self, kwargs: dict) -> None:
|
||||
if self.is_main_process:
|
||||
self.write(OTHER_CKPT_FILE_NAME, kwargs)
|
Reference in New Issue
Block a user