Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 11:02:05 +00:00
[checkpointio] fix hybrid plugin model save (#6106)
@@ -21,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     yield block, block_size
 
         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
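For context on why the old check was wrong: model._non_persistent_buffers_set on the root module holds only the unqualified names of buffers registered directly on the root, while model.named_buffers() yields dotted, fully qualified names for submodule buffers. Non-persistent buffers on submodules therefore never matched the old condition and were written into the checkpoint anyway. A minimal sketch of the mismatch, using a hypothetical toy model and assuming PyTorch is available:

import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        # Non-persistent: should be excluded from saved state.
        self.register_buffer("cache", torch.zeros(1), persistent=False)

model = nn.Sequential(Inner())

print(model._non_persistent_buffers_set)      # set(): root has no buffers of its own
print([n for n, _ in model.named_buffers()])  # ['0.cache']: qualified submodule name
# '0.cache' is not in set(), so the old code saved it despite persistent=False.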
@@ -352,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(name)
 
         # Load buffers.
-        non_persistent_buffers = set()
-        for n, m in model.named_modules():
-            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        non_persistent_buffers = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persistent_buffers:
                 _load(name)
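The inline loop removed above suggests what the shared helper does: walk every submodule and collect the fully qualified names of its non-persistent buffers. A sketch of that behavior (the actual colossalai.utils implementation may differ in details):

from typing import Set
import torch.nn as nn

def get_non_persistent_buffers_set(model: nn.Module) -> Set[str]:
    """Collect fully qualified names of non-persistent buffers across all submodules."""
    names: Set[str] = set()
    for mod_name, mod in model.named_modules():
        # Root module has an empty name; avoid a leading dot for its buffers.
        prefix = mod_name + "." if mod_name else ""
        names |= {prefix + buf_name for buf_name in mod._non_persistent_buffers_set}
    return names

Centralizing this in one utility lets both the save path (second hunk) and the load path (third hunk) filter buffers with the same qualified-name logic.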