mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[booster] gemini plugin support shard checkpoint (#3610)
* gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin add shard checkpoint save/load * gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint * [API Refactoring]gemini plugin support shard checkpoint --------- Co-authored-by: luchen <luchen@luchendeMBP.lan> Co-authored-by: luchen <luchen@luchendeMacBook-Pro.local>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Union
|
||||
import os
|
||||
import json
|
||||
|
||||
from .utils import is_dtensor_checkpoint
|
||||
|
||||
@@ -18,8 +20,8 @@ class CheckpointIndexFile:
|
||||
>>> index.export('new_index.json')
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.root_path = None
|
||||
def __init__(self, root_path=None) -> None:
|
||||
self.root_path = root_path
|
||||
self.metadata: dict = dict()
|
||||
self.weight_map: dict = dict()
|
||||
|
||||
@@ -154,3 +156,13 @@ class CheckpointIndexFile:
|
||||
Get all the weight keys.
|
||||
"""
|
||||
return list(self.weight_map.keys())
|
||||
|
||||
def write_index_file(self, save_index_file):
    """
    Write the index file.

    Serializes ``self.metadata`` and ``self.weight_map`` into a JSON document
    of the form ``{"metadata": ..., "weight_map": ...}`` and writes it to disk
    as UTF-8 text followed by a trailing newline.

    Args:
        save_index_file: name (or path) of the index file to write. It is
            joined onto ``self.root_path`` when that attribute is set, and
            used as given when ``self.root_path`` is None.
    """
    # root_path may be None (it is the constructor default); os.path.join
    # would raise TypeError on None, so fall back to the path as given.
    if self.root_path is not None:
        save_index_file = os.path.join(self.root_path, save_index_file)

    index = {"metadata": self.metadata, "weight_map": self.weight_map}

    with open(save_index_file, "w", encoding="utf-8") as f:
        # sort_keys makes the emitted file independent of insertion order.
        f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
|
||||
|
Reference in New Issue
Block a user