Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-04 18:40:28 +00:00
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load

* gemini plugin support shard checkpoint

* [API Refactoring] gemini plugin support shard checkpoint

---------

Co-authored-by: luchen <luchen@luchendeMBP.lan>
Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
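For context, a minimal usage sketch of the feature this commit adds, run under torchrun. It is not part of the diff: the Booster/GeminiPlugin wiring and the `shard=True` flag on `booster.save_model` are assumptions about the booster API of this era, so treat the exact argument names as hypothetical.

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})

plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.Adam(model.parameters())
# boost() wraps the model in GeminiDDP, so the checkpoint IO below applies
model, optimizer, *_ = booster.boost(model, optimizer)

# shard=True (assumed flag name) would route to
# GeminiCheckpointIO.save_sharded_model, writing one file per shard
# plus a JSON index into ./ckpt
booster.save_model(model, "./ckpt", shard=True)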
@@ -1,6 +1,9 @@
 import random
 import warnings
 from typing import Callable, List, Optional, Tuple, Union
+from pathlib import Path
+import os
+import logging
 
 import numpy as np
 import torch
@@ -20,6 +23,13 @@ from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
 from colossalai.zero.gemini.memory_tracer import MemStats
 
+from colossalai.checkpoint_io.utils import (
+    get_base_filenames,
+    get_shard_filename
+)
+
+from colossalai.checkpoint_io import CheckpointIndexFile
+
 from .plugin_base import Plugin
 
 __all__ = ['GeminiPlugin']
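The newly imported CheckpointIndexFile accumulates the shard index that save_sharded_model (below) writes next to the weight files. A sketch of the resulting index, assuming a Hugging-Face-style layout; the field and file names here are illustrative, not read from this diff:

# Index written by index_file.write_index_file(save_index_file); layout assumed.
index = {
    "metadata": {"total_size": 6442450944},    # filled by append_meta_data("total_size", ...)
    "weight_map": {                            # filled by append_weight_map(key, shard_file)
        "embed.weight": "pytorch_model-00001-of-00003.bin",
        "layers.0.attn.weight": "pytorch_model-00001-of-00003.bin",
        "head.weight": "pytorch_model-00003-of-00003.bin",
    },
}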
@@ -62,6 +72,40 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
+    def save_sharded_model(self, model: GeminiDDP, checkpoint_path: str, gather_dtensor: bool = False, variant: Optional[str] = None, max_shard_size: int = 1024, use_safetensors: bool = False):
+        """
+        Save sharded model.
+        """
+        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)
+        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
+        total_size = 0
+        index_file = CheckpointIndexFile(checkpoint_path)
+        for idx, shard_pair in enumerate(state_dict_shard):
+            if not self.coordinator.is_master():
+                continue
+            shard = shard_pair[0]
+            shard_file = get_shard_filename(weights_name, idx)
+            total_size = total_size + shard_pair[1]
+            for key in shard.keys():
+                index_file.append_weight_map(key, shard_file)
+
+            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
+            save_state_dict(shard, checkpoint_file_path, use_safetensors)
+
+        index_file.append_meta_data("total_size", total_size)
+        index_file.write_index_file(save_index_file)
+        logging.info(
+            f"The model is split into checkpoint shards. "
+            f"You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+
+    def load_sharded_model(self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False):
+        """
+        Load sharded model from multiple files.
+        """
+        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
 class GeminiModel(ModelWrapper):
 
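save_sharded_model consumes model.state_dict_shard(...) as an iterator of (shard, size) pairs. Below is a simplified stand-in for that contract; it is a hypothetical helper, not Gemini's actual implementation, which additionally gathers ZeRO-partitioned parameters and can restrict output to rank 0 (only_rank_0=True).

from typing import Dict, Iterator, Tuple
import torch

def state_dict_shard_sketch(state_dict: Dict[str, torch.Tensor],
                            max_shard_size: int = 1024) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    # max_shard_size is presumably a megabyte budget (default 1024 in the diff above)
    budget = max_shard_size * 1024 * 1024
    shard: Dict[str, torch.Tensor] = {}
    size = 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        if shard and size + nbytes > budget:
            yield shard, size    # emit a full shard, then start a new one
            shard, size = {}, 0
        shard[name] = tensor
        size += nbytes
    if shard:
        yield shard, size        # emit the final, possibly smaller shard

Note that in the diff, non-master ranks still drain the iterator: the `if not self.coordinator.is_master(): continue` only skips the index bookkeeping and file writes. If state_dict_shard performs collective gathers internally, this keeps all ranks stepping through the same collectives.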