Mirror of https://github.com/hpcaitech/ColossalAI.git
[async io] support async io (#6137)
* support async optimizer save/load
* fix
* fix
* support pin mem
* Update low_level_zero_plugin.py
* fix
* fix
* fix
* fix
* fix
Parent: b90835bd32 · Commit: eb69e640e5
@@ -359,6 +359,7 @@ class Booster:
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ) -> None:
         """
         Save optimizer to checkpoint.
@@ -374,7 +375,9 @@ class Booster:
                 names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
+        self.checkpoint_io.save_optimizer(
+            optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard, use_async=use_async
+        )

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
         """Save lr scheduler to checkpoint.
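For context, this is roughly how the new flag is exercised from user code, as a minimal sketch mirroring the updated low-level-zero test later in this diff; the plugin choice and the model/optimizer setup are illustrative, and the `_sync_d2h`/`_sync_io` calls drain the background writers before the file is read back:

# Hypothetical usage sketch; assumes colossalai.launch has already
# initialized the distributed environment and built model/optimizer.
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

def save_optimizer_async(model, optimizer, path: str = "optimizer.safetensors"):
    booster = Booster(plugin=LowLevelZeroPlugin())
    model, optimizer, *_ = booster.boost(model, optimizer)
    # use_async=True routes the save through AsyncFileWriter instead of torch.save
    booster.save_optimizer(optimizer, path, shard=False, use_async=True)
    booster.checkpoint_io._sync_d2h()  # flush device-to-host copies
    booster.checkpoint_io._sync_io()   # wait for the queued file writes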
@@ -94,7 +94,9 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)

-    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
@@ -178,7 +180,13 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

     def save_sharded_optimizer(
-        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: GeminiOptimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
@@ -24,6 +24,7 @@ from colossalai.checkpoint_io.utils import (
     get_shard_filename,
     load_param_groups_into_optimizer,
     load_shard_state_dict,
+    load_state_dict,
     load_states_into_optimizer,
     save_param_groups,
     save_state_dict,
@@ -113,7 +114,9 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False, use_async: bool = False
+    ):
         """Save optimizer to checkpoint but only on master process.

         Args:
@@ -125,9 +128,34 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
-        state_dict = optimizer.state_dict()
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        state_dict = optimizer.state_dict(pinned_state_dicts)
         if self.coordinator.is_master():
-            save_state_dict(state_dict, checkpoint, use_safetensors=False)
+            if use_async:
+                from tensornvme.async_file_io import AsyncFileWriter
+
+                from colossalai.utils.safetensors import save_nested
+
+                f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
+                save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+                self.async_writers.append(f_writer)
+            else:
+                save_state_dict(state_dict, checkpoint, use_safetensors=False)
+
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
+        use_async = checkpoint.endswith(".safetensors")
+        if use_async:
+            from colossalai.utils.safetensors import load_flat
+
+            checkpoint = load_flat(checkpoint)
+        else:
+            checkpoint = load_state_dict(checkpoint)
+        optimizer.load_state_dict(checkpoint)

     def save_sharded_optimizer(
         self,
@@ -136,6 +164,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         gather_dtensor: bool = False,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded Zero-optimizer checkpoint under the given checkpointing path.
@@ -161,10 +190,16 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         # state_dict only provide only 'param_groups'
         state_dict = optimizer.optim.state_dict()
         # state shard would be handled by the low-level zero optimizer
-        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
+        if use_async:
+            if id(optimizer) not in self.pinned_state_dicts:
+                self.pinned_state_dicts[id(optimizer)] = {}
+            pinned_state_dicts = self.pinned_state_dicts[id(optimizer)]
+        else:
+            pinned_state_dicts = None
+        sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard, pinned_state_dicts=pinned_state_dicts)

         # Preparing file paths and index file.
-        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
+        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix, use_safetensors=use_async)
         index_file = CheckpointIndexFile(checkpoint)
         index_file.append_meta_data("param_groups", param_group_file)

@@ -184,7 +219,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

             checkpoint_file_path = os.path.join(checkpoint, shard_file)
             if self.coordinator.is_master():
-                save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
+                if use_async:
+                    from tensornvme.async_file_io import AsyncFileWriter
+
+                    from colossalai.utils.safetensors import save_nested
+
+                    f_writer = AsyncFileWriter(
+                        fp=open(checkpoint_file_path, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread"
+                    )
+                    save_nested(f_writer, shard)
+                    self.async_writers.append(f_writer)
+                else:
+                    save_state_dict(shard, checkpoint_file_path, use_safetensors=False)

         # Wrap up index file.
         index_file.append_meta_data("total_size", total_size)
@@ -223,7 +269,12 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

         for shard_file in checkpoint_files:
-            state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            if shard_file.endswith(".safetensors"):
+                from colossalai.utils.safetensors import load_flat
+
+                state_dict = load_flat(shard_file)
+            else:
+                state_dict = load_shard_state_dict(Path(shard_file), use_safetensors=False)
             # shard state dict
             for param_idx, state in state_dict.items():
                 for k, v in state.items():
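The pattern above caches one set of pinned CPU buffers per optimizer, keyed by id(optimizer), so repeated async saves reuse the same page-locked memory instead of reallocating it on every checkpoint. A standalone sketch of the idea (illustrative names, independent of ColossalAI):

import torch

_pinned_buffers: dict = {}

def copy_to_pinned(owner, name: str, tensor: torch.Tensor) -> torch.Tensor:
    # one buffer dict per owner object, keyed by id(), as in the plugin above
    per_owner = _pinned_buffers.setdefault(id(owner), {})
    if name not in per_owner:
        # page-locked host memory speeds up device-to-host copies and
        # lets them overlap with the asynchronous file write
        per_owner[name] = torch.empty(tensor.shape, dtype=tensor.dtype, device="cpu", pin_memory=True)
    per_owner[name].copy_(tensor)
    return per_owner[name]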
@@ -52,7 +52,9 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -113,13 +115,16 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint but only on master process.
         """
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(
+                optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )

     def load_sharded_optimizer(
         self,
@@ -67,7 +67,9 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to checkpoint but only on master process.
         """
@@ -157,7 +159,13 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: str,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint but only on master process.
@@ -213,6 +213,7 @@ class CheckpointIO(ABC):
         gather_dtensor=True,
         prefix: str = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save optimizer to checkpoint. Optimizer states saving is not compatible with safetensors.
@@ -229,11 +230,12 @@ class CheckpointIO(ABC):
             prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None.
             size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True.
         """
-
         if shard:
-            self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            self.save_sharded_optimizer(
+                optimizer, checkpoint, gather_dtensor, prefix, size_per_shard, use_async=use_async
+            )
         else:
-            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
+            self.save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor, use_async=use_async)

     # ========================================================
     # Abstract methods for model loading/saving implementation
@@ -326,7 +328,13 @@ class CheckpointIO(ABC):

     @abstractmethod
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self,
+        optimizer: Optimizer,
+        checkpoint: Path,
+        gather_dtensor: bool,
+        prefix: str,
+        size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save optimizer to sharded checkpoint.
@@ -340,7 +348,9 @@ class CheckpointIO(ABC):
         """

     @abstractmethod
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer to unsharded checkpoint.
@@ -98,6 +98,7 @@ class GeneralCheckpointIO(CheckpointIO):
         gather_dtensor: bool,
         prefix: str,
         size_per_shard: int,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -155,6 +156,7 @@ class GeneralCheckpointIO(CheckpointIO):
         optimizer: Optimizer,
         checkpoint: Path,
         gather_dtensor: bool,
+        use_async: bool = False,
     ):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)
@@ -416,6 +416,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -725,7 +726,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()

-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool, use_async: bool = False
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -369,6 +369,7 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
+        use_async: bool = False,
     ):
         """
         Save sharded optimizer checkpoint under the given checkpointing path.
@@ -729,7 +730,13 @@ class MoECheckpointIO(HybridParallelCheckpointIO):
         dist.barrier()

     # Copied from colossalai.moe
-    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(
+        self,
+        optimizer: OptimizerWrapper,
+        checkpoint: str,
+        gather_dtensor: bool,
+        use_async: bool = False,
+    ):
         """
         Save optimizer state dict to a file with given path.
@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
 STATES_NAME = "pytorch_optim.bin"
+SAFE_STATE_NAME = "optimizer.safetensors"
 SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
 WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
 STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
+SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
 GROUP_FILE_NAME = "pytorch_optim_group.bin"

 # ======================================
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
     return weights_name, save_index_file


-def get_optimizer_base_filenames(prefix: str = None):
+def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
     """
     generate base optimizer state filenames
     """
-    states_name = STATES_NAME
+    states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
     states_name = add_prefix(states_name, prefix)

-    save_index_file = STATES_INDEX_NAME
+    save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
     save_index_file = add_prefix(save_index_file, prefix)

     param_group_file = GROUP_FILE_NAME
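For illustration, assuming add_prefix leaves the names untouched when prefix is None, the helper now resolves to (hypothetical REPL output):

get_optimizer_base_filenames(None)
# -> ("pytorch_optim.bin", "pytorch_optim.bin.index.json", "pytorch_optim_group.bin")
get_optimizer_base_filenames(None, use_safetensors=True)
# -> ("optimizer.safetensors", "optimizer.safetensors.index.json", "pytorch_optim_group.bin")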
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict
+from typing import Any, List, OrderedDict, Tuple

 import torch
 import torch.distributed as dist
@@ -78,7 +78,9 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            assert v1 == v2, f"{v1} not equals to {v2}"
+            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
+                v2 = tuple(v2)
+            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
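The tuple coercion matters because entries such as betas=(0.9, 0.999) in param_groups come back as lists after the json.dumps/json.loads round trip that the safetensors metadata goes through (see prepare and load_flat below); a quick illustration of the comparison the test helper now accepts:

v1 = (0.9, 0.999)  # "betas" as stored in a live optimizer's param_groups
v2 = [0.9, 0.999]  # the same entry after a JSON round trip
if isinstance(v1, tuple) and not isinstance(v2, tuple):
    v2 = tuple(v2)
assert v1 == v2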
@@ -1,10 +1,11 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
+import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-from safetensors.torch import _TYPES
+from safetensors.torch import _TYPES, load_file, safe_open

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -27,34 +28,93 @@ class PreparedData:
     offset: int


-def prepare(data: Dict[str, torch.Tensor]) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+def flatten_dict(nested_dict, parent_key="", separator="^"):
+    """
+    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
+
+    nested_dict: The input nested dictionary.
+    parent_key: The parent key currently being processed.
+    separator: The separator used to join keys; defaults to '^' but can be customized.
+    :return: A flattened dictionary.
+    """
+    items = []
+    for k, v in nested_dict.items():
+        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
+        if isinstance(v, dict):
+            items.extend(flatten_dict(v, new_key, separator).items())
+        else:
+            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
+            items.append((new_key, v))
+
+    return dict(items)
+
+
+def unflatten_dict(flattened_dict, separator="^"):
+    """
+    Restore a flattened dictionary back to a multi-level nested dictionary.
+
+    flattened_dict: The flattened dictionary.
+    separator: The separator used during flattening; defaults to '^' but can be customized.
+    :return: The restored nested dictionary.
+    """
+    nested_dict = {}
+    for key, value in flattened_dict.items():
+        keys = key.split(separator)
+        try:
+            keys[0] = int(keys[0])
+        except ValueError:
+            warnings.warn(f"{keys[0]} can't be converted to an integer")
+        d = nested_dict
+        for part in keys[:-1]:
+            if part not in d:
+                d[part] = {}
+            d = d[part]
+        assert isinstance(value, torch.Tensor)
+        d[keys[-1]] = value
+
+    return nested_dict
+
+
+def prepare(
+    data: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> Tuple[PreparedData, List[torch.Tensor], List[str]]:
+    if metadata is not None:
+        assert isinstance(metadata, dict)
+        for k, v in metadata.items():
+            metadata[k] = json.dumps(v)
+            assert isinstance(k, str)
+            assert isinstance(metadata[k], str)
+
     tensors = []
     tensor_keys = []
-    metadata = {}
+    header = {}
     offset = 0

+    if metadata is not None:
+        header["__metadata__"] = metadata
+
     for name, tensor in data.items():
         n = tensor.numel() * tensor.element_size()
         tensor_info = TensorInfo(
             dtype=_TYPES_INV[tensor.dtype], shape=list(tensor.shape), data_offsets=(offset, offset + n)
         )
         offset += n
-        metadata[name] = asdict(tensor_info)
+        header[name] = asdict(tensor_info)
         tensors.append(tensor)
         tensor_keys.append(name)

-    metadata_buf = json.dumps(metadata).encode("utf-8")
+    header_buf = json.dumps(header).encode("utf-8")

-    extra = (8 - len(metadata_buf) % 8) % 8
-    metadata_buf += b" " * extra
+    extra = (8 - len(header_buf) % 8) % 8
+    header_buf += b" " * extra

-    n = len(metadata_buf)
+    n = len(header_buf)

-    return PreparedData(n=n, header_bytes=metadata_buf, offset=offset), tensors, tensor_keys
+    return PreparedData(n=n, header_bytes=header_buf, offset=offset), tensors, tensor_keys


-def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
-    prepared_data, tensors, _ = prepare(state_dict)
+def save(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    prepared_data, tensors, _ = prepare(state_dict, metadata)
     n, header_bytes, _ = prepared_data.n, prepared_data.header_bytes, prepared_data.offset

     f_writer.write(n.to_bytes(8, byteorder="little"))
@@ -64,6 +124,13 @@ def save(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None
     f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


+def save_nested(
+    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
+) -> None:
+    flatten_data = flatten_dict(state_dict)
+    save(f_writer, flatten_data, metadata)
+
+
 def move_and_save(
     f_writer: AsyncFileWriter,
     state_dict: Dict[str, torch.Tensor],
@@ -81,3 +148,16 @@ def move_and_save(
             f_writer.write_tensor(state_dict[name], state_dict_pinned[name])
         else:
             f_writer.write_tensor(state_dict[name])
+
+
+def load_flat(checkpoint_path):
+    with safe_open(checkpoint_path, framework="pt") as f:
+        metadata = f.metadata()
+    state_dict_load = load_file(checkpoint_path)
+    state_dict = unflatten_dict(state_dict_load)
+    if metadata is None:
+        return state_dict
+    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
+    combined_state_dict = {"state": state_dict}
+    combined_state_dict.update(metadata)
+    return combined_state_dict
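Since safetensors can only store a flat mapping from string keys to tensors, nested optimizer state is flattened with '^'-joined keys on save and rebuilt on load. A small round-trip illustration using the functions above:

import torch

from colossalai.utils.safetensors import flatten_dict, unflatten_dict

nested = {0: {"exp_avg": torch.zeros(2)}, 1: {"exp_avg": torch.ones(2)}}
flat = flatten_dict(nested)      # {"0^exp_avg": tensor([0., 0.]), "1^exp_avg": tensor([1., 1.])}
restored = unflatten_dict(flat)  # top-level keys are converted back to int
assert torch.equal(restored[0]["exp_avg"], nested[0]["exp_avg"])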
@@ -770,7 +770,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         return {"state": packed_state, "param_groups": param_groups}

-    def state_dict(self) -> Dict:
+    def state_dict(self, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None) -> Dict:
         """Return a state_dict same with DDP

         Returns:
@@ -779,15 +779,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         zero_state = dict()
         device = get_accelerator().get_current_device()
         for param, state in self.optim.state.items():
+            if pinned_state_dicts and param not in pinned_state_dicts:
+                pinned_state_dicts[param] = {}
             zero_state[param] = copy.deepcopy(state)
             for k, v in state.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param]:
+                        pinned_state_dicts[param][k] = torch.empty_like(working_param, pin_memory=True, device="cpu")
                     working_param = self.master_to_working_param[id(param)]
                     pg = self.param_to_pg[working_param]
                     gathered_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(gathered_tensor, v.to(device).flatten(), pg)
-                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param).cpu()
-                    zero_state[param][k] = param_state
+                    param_state = gathered_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param][k].copy_(param_state)
+                        zero_state[param][k] = pinned_state_dicts[param][k]
+                    else:
+                        zero_state[param][k] = param_state.cpu()

         states_dict = self._pack_state(zero_state)

@@ -822,7 +830,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):

         self.optim.load_state_dict(zero_state_dict)

-    def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
+    def state_dict_shard(
+        self, max_shard_size: int = 1024, pinned_state_dicts: Optional[Dict[str, Dict[str, torch.Tensor]]] = None
+    ) -> Iterator[Tuple[Dict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
         Only include the 'state' in state_dict.

@@ -847,18 +857,27 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         for param_idx, states in local_states.items():
             current_block_size = 0
             current_block = copy.deepcopy(states)

+            if pinned_state_dicts and param_idx not in pinned_state_dicts:
+                pinned_state_dicts[param_idx] = {}
             master_param = idx2master[param_idx]
             working_param = self.master_to_working_param[id(master_param)]
             pg = self.param_to_pg[working_param]

             for k, v in states.items():
                 if isinstance(v, torch.Tensor) and k != "step":
+                    if pinned_state_dicts and k not in pinned_state_dicts[param_idx]:
+                        pinned_state_dicts[param_idx][k] = torch.empty_like(
+                            working_param, pin_memory=True, device="cpu"
+                        )
                     state_tensor = torch.empty(v.numel() * get_nd_world_size(pg), device=device, dtype=v.dtype)
                     all_gather_into_flat_tensor_nd(state_tensor, v.to(device).flatten(), pg)
-                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param).cpu()
+                    state_tensor = state_tensor[: working_param.numel()].reshape_as(working_param)
+                    if pinned_state_dicts:
+                        pinned_state_dicts[param_idx][k].copy_(state_tensor)
+                        current_block[k] = pinned_state_dicts[param_idx][k]
+                    else:
+                        current_block[k] = state_tensor.cpu()
                     current_block_size += state_tensor.numel()
-                    current_block[k] = state_tensor

             if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
                 yield ret_block, ret_block_size
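Callers opt in by passing a dict that outlives the call, as the low-level-zero plugin does above; a hypothetical sketch of driving the sharded variant:

def dump_optimizer_shards(zero_optimizer):
    # hypothetical caller: `zero_optimizer` is a boosted LowLevelZeroOptimizer;
    # reusing one dict across saves lets the pinned buffers be allocated once
    pinned = {}
    for block, block_numel in zero_optimizer.state_dict_shard(max_shard_size=1024, pinned_state_dicts=pinned):
        # each block maps param index -> {state name -> tensor}, with the
        # tensors backed by the cached pinned buffers once they exist
        print(block_numel, list(block.keys()))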
@@ -51,6 +51,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
         model_ckpt_path = f"{model_ckpt_path}.pt"
     if not shard and use_async:
         model_ckpt_path = f"{model_ckpt_path}.safetensors"
+    if not shard and use_async:
+        optimizer_ckpt_path = f"{tempdir}/optimizer.safetensors"
     booster.save_model(
         model,
         model_ckpt_path,
@@ -59,7 +61,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool, us
     )

     # lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
-    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
+    booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, use_async=use_async)
     booster.checkpoint_io._sync_d2h()
     booster.checkpoint_io._sync_io()
     dist.barrier()
@@ -139,7 +141,6 @@ def run_fn(stage, shard, offload, model_fn, data_gen_fn, output_transform_fn, lo
         assert torch.equal(
             working_shard, master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device)
         )
-
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict())
tests/test_checkpoint_io/test_safetensors_async_io.py (new file, 127 lines)
@@ -0,0 +1,127 @@
+import tempfile
+from copy import deepcopy
+
+import torch
+
+from colossalai.utils.safetensors import load_flat, save_nested
+
+try:
+    from tensornvme.async_file_io import AsyncFileWriter
+except ModuleNotFoundError:
+    raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
+
+from colossalai.testing import check_state_dict_equal
+
+
+def test_save_load():
+    with tempfile.TemporaryDirectory() as tempdir:
+        optimizer_state_dict = {
+            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
+        }
+        # group_dict = {"param_groups": [0, 1, 2]}
+        group_dict = {
+            "param_groups": [
+                {
+                    "lr": 0.001,
+                    "betas": (0.9, 0.999),
+                    "eps": 1e-08,
+                    "weight_decay": 0,
+                    "bias_correction": True,
+                    "params": [
+                        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+                        16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,
+                        32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+                        48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61,
+                    ],
+                }
+            ]
+        }
+        metadata = deepcopy(group_dict)
+        optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
+
+        save_nested(f_writer, optimizer_state_dict, metadata)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(optimizer_saved_path)
+        state_dict = load_state_dict["state"]
+        group = {"param_groups": load_state_dict["param_groups"]}
+        check_state_dict_equal(optimizer_state_dict, state_dict)
+        check_state_dict_equal(group_dict, group)
+
+        model_state_dict = {
+            "module.weight0": torch.rand((1024, 1024)),
+            "module.weight1": torch.rand((1024, 1024)),
+            "module.weight2": torch.rand((1024, 1024)),
+        }
+        model_saved_path = f"{tempdir}/save_model.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, model_state_dict)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+
+        load_state_dict = load_flat(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)