[optim] hotfix adam load (#6146)

* [optim] hotfix adam load

* [checkpointio] fix optimizer async io

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [checkpointio] update test

* [checkpointio] update test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu 2024-11-20 16:36:37 +08:00 committed by GitHub
parent 5caad13055
commit cf519dac6a
5 changed files with 139 additions and 76 deletions


@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
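
Note: the call-site contract is now that save_nested receives the whole optimizer state dict and derives the metadata itself. A minimal sketch of the new convention (the path and n_entries value are illustrative, mirroring the updated test below):

import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils.safetensors import save_nested

# A toy optimizer state dict in the same 3-level shape torch produces.
state_dict = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(4)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
f_writer = AsyncFileWriter(fp=open("optim.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, state_dict)  # old API: save_nested(f_writer, state_dict["state"], {"param_groups": ...})
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()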


@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
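
Note: this is the hotfix itself. load_flat hands step back as a 0-dim tensor, while CPUAdam's update path treats step as a plain Python int (the fused kernel is fed scalars). A sketch of the guarantee the override provides, assuming the cpu_adam extension is built (BUILD_EXT=1) and the usual import path:

import torch
from colossalai.nn.optimizer import CPUAdam  # assumed import path

model = torch.nn.Linear(4, 4)
opt = CPUAdam(model.parameters(), lr=1e-3)
model(torch.rand(2, 4)).sum().backward()
opt.step()

state = opt.state_dict()
# Simulate what a checkpoint restored via the flattened safetensors path holds.
state["state"][0]["step"] = torch.tensor(float(state["state"][0]["step"]))
opt.load_state_dict(state)  # the override casts step back to int
param = opt.param_groups[0]["params"][0]
assert isinstance(opt.state[param]["step"], int)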


@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                     v1 = v1.to(v2.dtype)
                 assert_close_loose(v1, v2)
             else:
-                if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                    v2 = tuple(v2)
-                assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+                assert v1 == v2, f"{v1} not equals to {v2}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
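
Note: the removed tuple special-case existed because the old JSON metadata path turned tuples (e.g. betas) into lists; the new pickled round trip preserves types, so plain equality suffices. In miniature:

import json
import pickle

betas = (0.9, 0.999)
# Old metadata path: JSON has no tuple type, so betas came back as a list.
assert json.loads(json.dumps(betas)) == [0.9, 0.999]
# New path: pickle preserves the tuple, so v1 == v2 holds without coercion.
assert pickle.loads(pickle.dumps(betas)) == betas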


@ -1,6 +1,5 @@
# a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
import json import json
import warnings
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
@ -12,6 +11,26 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer") raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
_TYPES_INV = {v: k for k, v in _TYPES.items()} _TYPES_INV = {v: k for k, v in _TYPES.items()}
import io
from torch.distributed.distributed_c10d import _pickler, _unpickler
def _object_to_tensor(obj, device):
f = io.BytesIO()
_pickler(f).dump(obj)
byte_storage = torch.ByteStorage._from_buffer(f.getvalue()) # type: ignore[attr-defined]
# Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
# Otherwise, it will casue 100X slowdown.
# See: https://github.com/pytorch/pytorch/issues/65696
byte_tensor = torch.ByteTensor(byte_storage).to(device)
return byte_tensor
def _tensor_to_object(tensor, tensor_size):
tensor = tensor.cpu()
buf = tensor.numpy().tobytes()[:tensor_size]
return _unpickler(io.BytesIO(buf)).load()
@dataclass @dataclass
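
Note: the two helpers above are lifted from torch.distributed; the byte-tensor detour lets arbitrary picklable objects ride inside a safetensors payload, which only stores tensors. Round-trip sketch (assumes the helpers are in scope):

import torch

obj = {"betas": (0.9, 0.999), "params": [0, 1]}
t = _object_to_tensor(obj, "cpu")  # 1-D torch.uint8 tensor
assert t.dtype == torch.uint8
assert _tensor_to_object(t, t.numel() * t.element_size()) == obj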
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol.
-    :return: A flattened dictionary.
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol.
-    :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": json.dumps(non_tensor_keys)}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict


 def prepare(
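
Note: on disk, a full optimizer state dict becomes flat "state.{idx}.{key}" tensors plus a pickled "param_groups", with the safetensors metadata recording which keys need unpickling on load. A sketch of the round trip using the private helpers above:

import torch

full = {
    "state": {0: {"step": torch.tensor(1.0), "exp_avg": torch.rand(4)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
flat, metadata = _flatten_optim_state_dict(full)
assert set(flat) == {"state.0.step", "state.0.exp_avg", "param_groups"}
restored = _unflatten_optim_state_dict(flat, metadata)
assert restored["param_groups"] == full["param_groups"]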
@@ -124,10 +162,8 @@ def save(
             f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
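
Note: load_flat now returns the reassembled dict directly; tensor-valued step entries stay tensors until an optimizer-side load_state_dict (such as the CPUAdam override above) casts them. Usage sketch, reusing the file written in the earlier save sketch:

import torch
from colossalai.utils.safetensors import load_flat

state_dict = load_flat("optim.safetensors")  # written by the save_nested sketch
assert set(state_dict) == {"state", "param_groups"}
assert isinstance(state_dict["state"][0]["step"], torch.Tensor)  # cast to int later by the optimizer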


@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy

 import torch
+from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +12,29 @@
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device


 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
+            ],
         }
-        metadata = deepcopy(group_dict)

         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])
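
Note: the shard round trip above leans on the "2-level dict" branch of _flatten_optim_state_dict: a dict without a top-level "state" key flattens and restores with no "param_groups" handling and, when every value is a tensor, no metadata at all. Sketch:

import torch

shard = {0: {"exp_avg": torch.rand(4)}, 1: {"exp_avg": torch.rand(4)}}
flat, metadata = _flatten_optim_state_dict(shard)
assert set(flat) == {"state.0.exp_avg", "state.1.exp_avg"}
assert metadata is None  # nothing to unpickle
restored = _unflatten_optim_state_dict(flat, metadata)
assert set(restored) == {0, 1}  # same 2-level shape back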
@@ -116,12 +132,22 @@ def test_save_load():
         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
         }
         model_saved_path = f"{tempdir}/save_model.safetensors"
         f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, model_state_dict)
+        save(f_writer, model_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

-        load_state_dict = load_flat(model_saved_path)
+        load_state_dict = load_file(model_saved_path)
+        check_state_dict_equal(model_state_dict, load_state_dict)
+        model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
+        model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
+        model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
+        f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
+        move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict = load_file(model_saved_path)
         check_state_dict_equal(model_state_dict, load_state_dict)
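
Note: the CUDA leg of the test exercises move_and_save, which stages device tensors through caller-provided pinned host buffers before the async write. A condensed sketch of that flow, assuming a CUDA device and mirroring the test's setup:

import torch
from tensornvme.async_file_io import AsyncFileWriter
from colossalai.utils import get_current_device
from colossalai.utils.safetensors import move_and_save

state_cuda = {"w": torch.rand(4).to(get_current_device())}
pinned = {k: v.cpu().pin_memory() for k, v in state_cuda.items()}  # pinned staging buffers
f_writer = AsyncFileWriter(fp=open("model_cuda.safetensors", "wb"), n_entries=191, backend="pthread")
move_and_save(f_writer, state_cuda, pinned)
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()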