From b83143ee72b8b694ded965465fec70ac339a4d9f Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Mon, 25 Nov 2024 10:51:49 +0800
Subject: [PATCH] fix

---
 colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 63b8dde98..499f21d1b 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -23,7 +23,7 @@ from colossalai.tensor.padded_tensor import (
     to_unpadded_tensor,
 )
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
-from colossalai.utils.safetensors import move_and_save
+from colossalai.utils.safetensors import save
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -708,7 +708,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if id(model) not in self.pinned_state_dicts:
                 self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
             self.async_writers.append(writer)
-            move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
+            save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
         else:
             save_state_dict(complete_state_dict, checkpoint, use_safetensors)
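
Note on the change: both the removed `move_and_save` and the new `save` receive the async writer, the gathered state dict, and a cached pinned state dict, so the call site only needs the rename. As a rough illustration of the underlying pattern (a sketch under assumptions, not ColossalAI's actual implementation: `create_pinned_state_dict` here is only modeled on the call site above, and `async_save` with its threading-based writer is a hypothetical stand-in):

    import threading
    import torch
    from safetensors.torch import save_file

    def create_pinned_state_dict(state_dict: dict) -> dict:
        # Allocate page-locked CPU buffers matching each tensor in the state dict.
        return {
            name: torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
            for name, t in state_dict.items()
        }

    def async_save(state_dict: dict, pinned: dict, path: str) -> threading.Thread:
        # Stage tensors into the pinned buffers; device-to-host copies can be
        # issued without blocking when the source lives on a CUDA device.
        for name, t in state_dict.items():
            pinned[name].copy_(t, non_blocking=True)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # ensure the copies finished before writing
        # Hand the actual file write to a background thread.
        writer = threading.Thread(target=save_file, args=(pinned, path))
        writer.start()
        return writer  # caller tracks it, as in self.async_writers.append(writer)

Caching the pinned buffers per model id (as the call site does with self.pinned_state_dicts[id(model)]) avoids re-allocating page-locked memory on every checkpoint, and deferring the file write to a background writer lets training resume while serialization proceeds.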