mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-22 01:48:07 +00:00
[async io]supoort async io (#6137)
* support async optimizer save/load * fix * fix * support pin mem * Update low_level_zero_plugin.py * fix * fix * fix * fix * fix
This commit is contained in:
committed by
Hongxin Liu
parent
b90835bd32
commit
eb69e640e5
@@ -24,9 +24,11 @@ from colossalai.utils.safetensors import move_and_save
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
STATES_NAME = "pytorch_optim.bin"
|
||||
SAFE_STATE_NAME = "optimizer.safetensors"
|
||||
SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
|
||||
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
|
||||
STATES_INDEX_NAME = "pytorch_optim.bin.index.json"
|
||||
SAFE_STATES_INDEX_NAME = "optimizer.safetensors.index.json"
|
||||
GROUP_FILE_NAME = "pytorch_optim_group.bin"
|
||||
|
||||
# ======================================
|
||||
@@ -838,14 +840,14 @@ def get_model_base_filenames(prefix: str = None, use_safetensors: bool = False):
|
||||
return weights_name, save_index_file
|
||||
|
||||
|
||||
def get_optimizer_base_filenames(prefix: str = None):
|
||||
def get_optimizer_base_filenames(prefix: str = None, use_safetensors: bool = False):
|
||||
"""
|
||||
generate base optimizer state filenames
|
||||
"""
|
||||
states_name = STATES_NAME
|
||||
states_name = SAFE_STATE_NAME if use_safetensors else STATES_NAME
|
||||
states_name = add_prefix(states_name, prefix)
|
||||
|
||||
save_index_file = STATES_INDEX_NAME
|
||||
save_index_file = SAFE_STATES_INDEX_NAME if use_safetensors else STATES_INDEX_NAME
|
||||
save_index_file = add_prefix(save_index_file, prefix)
|
||||
|
||||
param_group_file = GROUP_FILE_NAME
|
||||
|
Reference in New Issue
Block a user