mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-08 11:27:24 +00:00
fix
This commit is contained in:
parent
fa0318dba5
commit
b83143ee72
@ -23,7 +23,7 @@ from colossalai.tensor.padded_tensor import (
|
|||||||
to_unpadded_tensor,
|
to_unpadded_tensor,
|
||||||
)
|
)
|
||||||
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
from colossalai.utils import get_current_device, get_non_persistent_buffers_set
|
||||||
from colossalai.utils.safetensors import move_and_save
|
from colossalai.utils.safetensors import save
|
||||||
|
|
||||||
from .general_checkpoint_io import GeneralCheckpointIO
|
from .general_checkpoint_io import GeneralCheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
@ -708,7 +708,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
|
||||||
self.async_writers.append(writer)
|
self.async_writers.append(writer)
|
||||||
move_and_save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
|
save(writer, complete_state_dict, self.pinned_state_dicts[id(model)])
|
||||||
else:
|
else:
|
||||||
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
save_state_dict(complete_state_dict, checkpoint, use_safetensors)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user