diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 63b8dde98..499f21d1b 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -23,7 +23,7 @@ from colossalai.tensor.padded_tensor import ( to_unpadded_tensor, ) from colossalai.utils import get_current_device, get_non_persistent_buffers_set -from colossalai.utils.safetensors import move_and_save +from colossalai.utils.safetensors import save from .general_checkpoint_io import GeneralCheckpointIO from .index_file import CheckpointIndexFile @@ -708,7 +708,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO): if id(model) not in self.pinned_state_dicts: self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) self.async_writers.append(writer) - move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)]) + save(writer, complete_state_dict, self.pinned_state_dicts[id(model)]) else: save_state_dict(complete_state_dict, checkpoint, use_safetensors)