This commit is contained in:
wangbluo 2024-11-25 10:51:49 +08:00
parent fa0318dba5
commit b83143ee72

View File

@ -23,7 +23,7 @@ from colossalai.tensor.padded_tensor import (
to_unpadded_tensor,
)
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
from colossalai.utils.safetensors import move_and_save
from colossalai.utils.safetensors import save
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@ -708,7 +708,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
self.async_writers.append(writer)
move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
else:
save_state_dict(complete_state_dict, checkpoint, use_safetensors)