[booster] gemini plugin support shard checkpoint (#3610)

* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring] gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
jiangmingyan authored 2023-05-05 14:37:21 +08:00, committed by GitHub
parent 0f785cb1f3
commit 307894f74d
9 changed files with 268 additions and 96 deletions


@@ -1,6 +1,9 @@
 import random
 import warnings
 from typing import Callable, List, Optional, Tuple, Union
+from pathlib import Path
+import os
+import logging

 import numpy as np
 import torch
@@ -20,6 +23,13 @@ from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.zero.gemini.memory_tracer import MemStats
+from colossalai.checkpoint_io.utils import (
+    get_base_filenames,
+    get_shard_filename
+)
+from colossalai.checkpoint_io import CheckpointIndexFile

 from .plugin_base import Plugin

 __all__ = ['GeminiPlugin']
@@ -62,6 +72,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)

+    def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
+        """
+        Save sharded model
+        """
+        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            if not self.coordinator.is_master():
+                continue
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(
+            f"The model is going to be split to checkpoint shards. "
+            f"You can find where each parameters has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+        """
+        load shard model, load model from multiple files
+        """
+        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
+

 class GeminiModel(ModelWrapper):
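Note: below is a minimal usage sketch of the new sharded-checkpoint path, mirroring the test added at the end of this commit. `gemini_model`, `new_model`, the checkpoint directory and the literal argument values are illustrative, not part of the diff; the positional booleans follow the ordering used in the test call to `save_model`.

# Assumed setup: `gemini_model` and `new_model` are ZeroDDP/GeminiDDP-wrapped modules,
# constructed the same way as in the test at the bottom of this commit.
ckpt_io = GeminiCheckpointIO()

# The positional arguments mirror the test (presumably shard=True, gather_dtensor=True,
# variant="", then the shard size); save_sharded_model() above streams shards from
# model.state_dict_shard() and writes one index file at the end.
ckpt_io.save_model(gemini_model, "./gemini_ckpt", True, True, "", 1024, use_safetensors=False)

# Loading from the directory goes through load_sharded_model(), which reads every shard
# listed in the index file with load_sub_module=False.
ckpt_io.load_model(new_model, "./gemini_ckpt", strict=True)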


@@ -1,12 +1,12 @@
 from pathlib import Path
+from functools import reduce
 import torch.nn as nn
 from torch.optim import Optimizer
 import logging
 import os
-import json
 import gc
-from typing import Optional
+from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
@@ -18,10 +18,9 @@ from .utils import (
     shard_checkpoint,
     load_shard_state_dict,
     load_state_dict_into_model,
-    add_variant
+    get_shard_filename,
+    get_base_filenames
 )
-from .utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME

 __all__ = ['GeneralCheckpointIO']
@@ -85,30 +84,32 @@ class GeneralCheckpointIO(CheckpointIO):
         # shard checkpoint
         state_dict = model.state_dict()
-        weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
-        weights_name = add_variant(weights_name, variant)
-        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+        state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
+
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)

-        # Save the model
-        for shard_file, shard in shards.items():
             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

-        # save index file
-        save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
-        save_index_file = os.path.join(checkpoint_path, add_variant(save_index_file, variant))
-
-        with open(save_index_file, "w", encoding="utf-8") as f:
-            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-            f.write(content)
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
         logging.info(
-            f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
-            f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the "
+            f"The model is going to be split to checkpoint shards. "
+            f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}."
         )

-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False,
+                           use_safetensors: bool = False, load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
@@ -122,17 +123,21 @@ class GeneralCheckpointIO(CheckpointIO):
         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()
-        missing_keys = ckpt_index_file.get_all_param_names()
+        missing_keys = []

         for shard_file in checkpoint_files:
             state_dict = load_shard_state_dict(Path(shard_file), use_safetensors)
-            load_state_dict_into_model(model, state_dict, missing_keys, strict)
+            load_state_dict_into_model(model, state_dict, missing_keys, strict, load_sub_module)
             del state_dict
             gc.collect()

-        if strict and len(missing_keys) > 0:
-            error_msgs = 'Missing key(s) in state_dict: {}. '.format(
-                ', '.join('"{}"'.format(k) for k in missing_keys))
-            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-                self.__class__.__name__, "\n\t".join(error_msgs)))
+        if strict:
+            remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+            if len(remain_keys) > 0:
+                error_msgs = 'Missing key(s) in state_dict: {}. '.format(
+                    ', '.join('"{}"'.format(k) for k in missing_keys))
+                raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                    self.__class__.__name__, "\n\t".join(error_msgs)))
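Note: with this change, `shard_checkpoint` streams `(shard, size)` pairs instead of returning all shards plus an index up front, and the strict-mode check now intersects the per-shard missing-key lists, because a key only counts as missing when no shard provides it. A toy illustration of that intersection (key names are made up):

from functools import reduce

# Each shard load appends the model keys it did NOT find to `missing_keys`,
# so a key is truly missing only if it is absent from every shard.
missing_keys = [
    ['fc.weight', 'fc.bias', 'conv.weight'],    # keys not satisfied by shard 0
    ['fc.weight', 'conv.weight', 'conv.bias'],  # keys not satisfied by shard 1
]
remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
print(remain_keys)  # {'fc.weight', 'conv.weight'} would be reported as missing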


@@ -1,6 +1,8 @@
 import json
 from pathlib import Path
 from typing import Any, List, Union
+import os
+import json

 from .utils import is_dtensor_checkpoint
@@ -18,8 +20,8 @@ class CheckpointIndexFile:
     >>> index.export('new_index.json')
     """

-    def __init__(self) -> None:
-        self.root_path = None
+    def __init__(self, root_path=None) -> None:
+        self.root_path = root_path
         self.metadata: dict = dict()
         self.weight_map: dict = dict()
@@ -154,3 +156,13 @@ class CheckpointIndexFile:
         Get all the weight keys.
         """
         return list(self.weight_map.keys())
+
+    def write_index_file(self, save_index_file):
+        """
+        Wriete index file.
+        """
+        save_index_file = os.path.join(self.root_path, save_index_file)
+        index = {"metadata": self.metadata, "weight_map": self.weight_map}
+
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
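Note: a small sketch of how the savers above use the extended `CheckpointIndexFile`; the directory, weight keys, shard file names and size value are made up for illustration.

from colossalai.checkpoint_io import CheckpointIndexFile

index_file = CheckpointIndexFile("./ckpt_dir")                      # root_path is now set via __init__
index_file.append_weight_map("encoder.weight", "model-00001.bin")   # weight key -> shard file
index_file.append_weight_map("decoder.weight", "model-00002.bin")
index_file.append_meta_data("total_size", 1234)                     # bookkeeping written by the savers
index_file.write_index_file("model.bin.index.json")                 # JSON index written under ./ckpt_dir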


@@ -2,7 +2,7 @@
 from pathlib import Path
 import torch
 import torch.nn as nn
-from typing import List, Dict, Mapping, OrderedDict, Optional, Tuple
+from typing import List, Mapping, OrderedDict, Optional, Tuple, Iterator
 from colossalai.tensor.d_tensor.d_tensor import DTensor

 import re
@@ -77,55 +77,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
 # ======================================
 # Helper functions for saving shard file
 # ======================================
-def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024, weights_name: str = WEIGHTS_NAME):
+def shard_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
     """
     Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
     given size.
     """
-    sharded_state_dicts = []
     current_block = {}
     current_block_size = 0
-    total_size = 0

     for key, weight in state_dict.items():
+        ret_block = None
+        ret_block_size = 0
         if type(weight) != DTensor:
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
             if current_block_size + weight_size > max_shard_size:
-                sharded_state_dicts.append(current_block)
+                ret_block = current_block
+                ret_block_size = current_block_size
                 current_block = {}
                 current_block_size = 0

             current_block[key] = weight
             current_block_size += weight_size
-            total_size += weight_size

-    # Add the last block
-    sharded_state_dicts.append(current_block)
+        if ret_block != None:
+            yield ret_block, ret_block_size

-    # If we only have one shard, we return it
-    if len(sharded_state_dicts) == 1:
-        return {weights_name: sharded_state_dicts[0]}, None
-
-    # Otherwise, let's build the index
-    weight_map = {}
-    shards = {}
-    for idx, shard in enumerate(sharded_state_dicts):
-        shard_file = weights_name.replace(".bin", f"-{idx+1:05d}-of-{len(sharded_state_dicts):05d}.bin")
-        shard_file = shard_file.replace(
-            ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors"
-        )
-        shards[shard_file] = shard
-        for key in shard.keys():
-            weight_map[key] = shard_file
-
-    # Add the metadata
-    metadata = {"total_size": total_size}
-    index = {"metadata": metadata, "weight_map": weight_map}
-    return shards, index
+    yield current_block, current_block_size


 def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     """
@@ -146,7 +126,7 @@ def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool =False):
     else:
         return torch.load(checkpoint_file)

-def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False):
+def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missing_keys: List, strict: bool = False, load_sub_module: bool = True):
     r"""Copies parameters and buffers from :attr:`state_dict` into
     this module and its descendants.
@@ -167,29 +147,22 @@ def load_state_dict_into_model(model: nn.Module, state_dict: torch.Tensor, missi
     if metadata is not None:
         state_dict._metadata = metadata

-    def load(module: nn.Module, state_dict, prefix=""):
+    def load(module: nn.Module, state_dict, prefix="", load_sub_module: bool = True):
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
-        args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
+        args = (state_dict, prefix, local_metadata, True, sub_missing_keys, [], error_msgs)
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
             module._load_from_state_dict(*args)

-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, state_dict, prefix + name + ".")
+        if load_sub_module:
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, state_dict, prefix + name + ".")

-    load(model, state_dict, "")
+    load(model, state_dict, "", load_sub_module)
     del load

-    # deal with missing key
-    if len(missing_keys) > 0:
-        deleted_keys = []
-        for key in missing_keys:
-            if key not in sub_missing_keys:
-                deleted_keys.append(key)
-        for key in deleted_keys:
-            missing_keys.remove(key)
+    missing_keys = missing_keys.append(sub_missing_keys)

     if strict:
         if len(unexpected_keys) > 0:
@@ -417,3 +390,24 @@ def add_variant(weights_name: str, variant: Optional[str] = None) -> str:
             weights_name = ".".join(splits)

     return weights_name
+
+
+def get_base_filenames(variant: str=None, use_safetensors: bool=False):
+    """
+    generate base weight filenames
+    """
+    weights_name = SAFE_WEIGHTS_NAME if use_safetensors else WEIGHTS_NAME
+    weights_name = add_variant(weights_name, variant)
+
+    save_index_file = SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME
+    save_index_file = add_variant(save_index_file, variant)
+
+    return weights_name, save_index_file
+
+def get_shard_filename(weights_name: str, idx: int):
+    """
+    get shard file name
+    """
+    shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
+    shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
+    return shard_file
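Note: a quick sketch of the reworked helpers, assuming nothing beyond what the diff shows. The tensors, the `max_shard_size` value and the `model.bin` base name are arbitrary; the real base name comes from `get_base_filenames`, i.e. `WEIGHTS_NAME`/`SAFE_WEIGHTS_NAME` plus the optional variant, and shard sizes follow whatever unit `calculate_tensor_size` uses.

import torch
from collections import OrderedDict
from colossalai.checkpoint_io.utils import shard_checkpoint, get_shard_filename

# shard_checkpoint() is now a generator that yields (shard_dict, shard_size) pairs
# lazily instead of materialising every shard plus an index up front.
state_dict = OrderedDict((f"layer{i}.weight", torch.zeros(1024, 1024)) for i in range(4))
for idx, (shard, size) in enumerate(shard_checkpoint(state_dict, max_shard_size=8)):
    # get_shard_filename() splices a 1-based, zero-padded shard index into the base name,
    # e.g. model-00001.bin, model-00002.bin, ...
    print(get_shard_filename("model.bin", idx), list(shard.keys()), size)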


@@ -2,7 +2,7 @@ import itertools
 from collections import OrderedDict
 from contextlib import nullcontext
 from functools import partial
-from typing import Dict, Iterator, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union, Tuple, Set

 import torch
 import torch.distributed as dist
@@ -96,8 +96,35 @@ class ZeroDDP(ColoDDP):
             param_name = m_name + '.' + p_name if m_name else p_name
             self.name2param[param_name] = p_var
         super().__init__(module, process_group=ColoProcessGroup())
+        self._non_persistent_buffers_set=self._get_non_persistent_buffers_set(module)
         self._cast_buffers()

+    def _get_non_persistent_buffers_set(self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = '', remove_duplicate: bool = True):
+        r"""
+        Args:
+            memo: a memo to store the set of modules already added to the result
+            prefix: a prefix that will be added to the name of the module
+            remove_duplicate: whether to remove the duplicated module instances in the result
+                or not
+        """
+        if memo is None:
+            memo = set()
+        self_non_persistent_set = set()
+        if module not in memo:
+            if remove_duplicate:
+                memo.add(module)
+            self_non_persistent_set = set(map(lambda key: prefix + ('.' if prefix else '') + key, module._non_persistent_buffers_set))
+        for name, sub_module in module._modules.items():
+            if sub_module is None:
+                continue
+            submodule_prefix = prefix + ('.' if prefix else '') + name
+            child_non_persistent_set = self._get_non_persistent_buffers_set(sub_module, memo, submodule_prefix, remove_duplicate)
+            self_non_persistent_set = set.union(self_non_persistent_set, child_non_persistent_set)
+        return self_non_persistent_set
+
     def _post_forward(self):
         """This function is only triggered for inference.
         """
@@ -604,7 +631,7 @@ class ZeroDDP(ColoDDP):
                          keep_vars: bool = False,
                          max_shard_size: int = 1024,
                          only_rank_0: bool = True,
-                         dtype: torch.dtype = torch.float16) -> Iterator[OrderedDict]:
+                         dtype: torch.dtype = torch.float16) -> Iterator[Tuple[OrderedDict, int]]:
         """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.

         Both parameters and persistent buffers (e.g. running averages) are included.
@@ -644,9 +671,9 @@ class ZeroDDP(ColoDDP):
                     gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
                 gathered_param = gathered_param_buffer.pop(fp32_param)

-            block = sharder.append(prefix + name, gathered_param)
+            block, block_size = sharder.append(prefix + name, gathered_param)
             if block is not None:
-                yield block
+                yield block, block_size

         del fp16_to_fp32
         del gathered_param_buffer
@@ -655,19 +682,19 @@ class ZeroDDP(ColoDDP):
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                block = sharder.append(prefix + name, buffer)
+                block, block_size = sharder.append(prefix + name, buffer)
                 if block is not None:
-                    yield block
+                    yield block, block_size

         # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
             extra_state = self.get_extra_state()
-            block = sharder.append(extra_state_key, extra_state)
+            block, block_size = sharder.append(extra_state_key, extra_state)
             if block is not None:
-                yield block
+                yield block, block_size

-        yield sharder.current_block
+        yield sharder.current_block, sharder.current_block_size


 class _StateDictSharder:
@@ -677,16 +704,18 @@ class _StateDictSharder:
         self.current_block = OrderedDict()
         self.current_block_size = 0

-    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
         tensor_size = calculate_tensor_size(tensor)
         ret_block = None
+        ret_block_size = 0
         if self.current_block_size + tensor_size > self.max_shard_size:
             ret_block = self.current_block
+            ret_block_size = self.current_block_size
             self.current_block = OrderedDict()
             self.current_block_size = 0
         self.current_block[name] = tensor
         self.current_block_size += tensor_size
-        return ret_block
+        return ret_block, ret_block_size


 class GeminiDDP(ZeroDDP):
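Note: a sketch of the new `state_dict_shard` contract, mirroring the updated tests at the end of this commit. `gemini_model` stands for a ZeroDDP/GeminiDDP-wrapped module and the `max_shard_size` value is arbitrary; non-persistent buffers are skipped using the prefixed name set collected above.

# state_dict_shard() now yields (shard, shard_size) pairs; the final, partially filled
# block is emitted last together with its size, so the sizes can be summed into the
# "total_size" metadata written by the checkpoint savers.
total = 0
for shard, shard_size in gemini_model.state_dict_shard(max_shard_size=32, only_rank_0=False):
    total += shard_size
    print(f"{len(shard)} tensors in this shard, size {shard_size}")
print(f"accumulated total_size metadata would be {total}")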


@@ -1,16 +1,21 @@
 import tempfile
 import pytest
 import torch
+import logging
 from torch.optim import Adam
 from torchvision.models import resnet18
+from pathlib import Path
+import os
+import subprocess

 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
 from colossalai.testing import clear_cache_before_run, parameterize
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from tests.components_to_test.registry import non_distributed_component_funcs

 # ========
 # Note:
 # 1. due to checkpoint IO can be quite slow if tested with all models, we will only test on resnet for now
@@ -83,7 +88,6 @@ def test_sharded_checkpoint(use_safetensors: bool):
         suffix = ".bin"
         WEIGHTS_INDEX_NAME = "model.bin.index.json"

-    # model_ckpt_dir = tempfile.TemporaryDirectory(suffix=suffix)
     model_ckpt_dir = tempfile.TemporaryDirectory()
     optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
@@ -104,6 +108,87 @@ def test_sharded_checkpoint(use_safetensors: bool):
     recursive_check(model.state_dict(), new_model.state_dict())
     recursive_check(optimizer.state_dict(), new_optimizer.state_dict())

+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('model_name', ['bert'])
+@parameterize('use_safetensors', [True, False])
+def hf_load_colossalai_checkpoint(placement_policy, model_name, use_safetensors: bool):
+    from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig, BertForSequenceClassification
+
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, *_ = get_components_func()
+    with ColoInitContext(device=get_current_device()):
+        bert_model = model_builder()
+    bert_model.config.save_pretrained(save_directory=model_ckpt_dir.name)
+    config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    bert_model = ZeroDDP(bert_model, gemini_manager)
+    bert_model.train()
+
+    ckpt_io = GeminiCheckpointIO()
+    if ckpt_io.coordinator.is_master():
+        model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
+        ckpt_io.save_model(bert_model, model_ckpt_dir.name, True, True, "", (model_size / 3), use_safetensors=use_safetensors)
+        new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
+        recursive_check(bert_model.state_dict(only_rank_0=True, dtype=torch.float32), new_bert_model.state_dict())
+
+    model_ckpt_dir.cleanup()
+
+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('model_name', ['gpt2', 'bert'])
+@parameterize('use_safetensors', [True, False])
+def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, *_ = get_components_func()
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+        new_model = model_builder()
+
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    model.train()
+
+    new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
+    new_chunk_manager = ChunkManager(new_config_dict)
+    new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
+    new_model = ZeroDDP(new_model, new_gemini_manager)
+
+    model_ckpt_dir = tempfile.TemporaryDirectory()
+    ckpt_io = GeminiCheckpointIO()
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
+    ckpt_io.save_model(model, model_ckpt_dir.name, True, True, "epoch", (model_size / 3), use_safetensors=use_safetensors)
+
+    # load model
+    if ckpt_io.coordinator.is_master():
+        ckpt_io.load_model(new_model, model_ckpt_dir.name, strict=True)
+        model_dict = model.state_dict(only_rank_0=True)
+        new_model_dict = new_model.state_dict(only_rank_0=True)
+        recursive_check(model_dict, new_model_dict)
+
+    model_ckpt_dir.cleanup()
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_state_dict()
+    hf_load_colossalai_checkpoint()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [4, 4])
+@rerun_if_address_is_in_use()
+def test_gemini_ckpIO(world_size):
+    spawn(run_dist, world_size)
+

 # do recursive check for the optimizer state dict
 # if the value is a dict, compare its values
@@ -117,10 +202,14 @@ def recursive_check(d1, d2):
         elif isinstance(v, list):
             for i in range(len(v)):
                 if isinstance(v[i], torch.Tensor):
+                    v[i] = v[i].to("cpu")
+                    d2[k][i] = d2[k][i].to("cpu")
                     assert torch.equal(v[i], d2[k][i])
                 else:
                     assert v[i] == d2[k][i]
         elif isinstance(v, torch.Tensor):
+            v = v.to("cpu")
+            d2[k] = d2[k].to("cpu")
             assert torch.equal(v, d2[k])
         else:
             assert v == d2[k]


@@ -31,14 +31,13 @@ def exam_state_dict(placement_policy, model_name: str):
     zero_dict = model.state_dict(only_rank_0=False)
     accumulated_keys = set()
     # ensure number of shards > 1
-    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+    for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
         for key, value in shard.items():
             assert key not in accumulated_keys, f"key `{key}` is duplicated."
             accumulated_keys.add(key)
             assert key in zero_dict, f"{key} not in ZeRO dictionary."
             assert torch.equal(value, zero_dict[key]), f"{key} not equal."

-
 def run_dist(rank, world_size, port):
     config = {}
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')