[checkpointio] fix hybrid plugin model save (#6106)

Hongxin Liu
2024-10-31 17:04:53 +08:00
committed by GitHub
parent 89a9a600bc
commit c2e8f61592
4 changed files with 41 additions and 38 deletions

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

@@ -21,7 +21,7 @@ from colossalai.tensor.padded_tensor import (
     to_padded_tensor,
     to_unpadded_tensor,
 )
-from colossalai.utils import get_current_device
+from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -105,8 +105,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     yield block, block_size
 
         # Save buffers.
+        non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
-            if buf is not None and name not in model._non_persistent_buffers_set:
+            if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
                 block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
                 if block is not None:
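
Why the old membership test was wrong: `Module._non_persistent_buffers_set` stores only the local, unqualified names of buffers registered directly on that one module, while `named_buffers()` yields dot-qualified names for the whole module tree. The check therefore missed every non-persistent buffer owned by a submodule, and such buffers leaked into the saved checkpoint. A minimal standalone repro of the mismatch (module and buffer names here are ours, not ColossalAI's):

import torch
import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        # A non-persistent buffer: PyTorch itself excludes it from state_dict().
        self.register_buffer("cache", torch.zeros(1), persistent=False)

model = nn.Sequential(Inner())

# The root module's set holds only *local* buffer names, so it is empty here...
print(model._non_persistent_buffers_set)             # -> set()
# ...while named_buffers() yields the qualified name "0.cache".
print([name for name, _ in model.named_buffers()])   # -> ['0.cache']
# Hence `"0.cache" not in model._non_persistent_buffers_set` is True, and
# before this fix the buffer was wrongly appended to the checkpoint shard.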
@@ -352,9 +353,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(name)
 
         # Load buffers.
-        non_persistent_buffers = set()
-        for n, m in model.named_modules():
-            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        non_persistent_buffers = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persistent_buffers:
                 _load(name)
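
For reference, a minimal sketch of what the shared `get_non_persistent_buffers_set` helper is assumed to do, generalizing the inlined loop removed above: walk every submodule and qualify each local buffer name with its module path. The actual implementation lives in `colossalai.utils`; this is an illustration, not a copy of it.

from typing import Set

import torch.nn as nn

def get_non_persistent_buffers_set(model: nn.Module) -> Set[str]:
    # Collect fully qualified names of all non-persistent buffers in the tree.
    names: Set[str] = set()
    for module_name, module in model.named_modules():
        for buffer_name in module._non_persistent_buffers_set:
            # The root module has an empty name; avoid emitting a leading dot.
            names.add(f"{module_name}.{buffer_name}" if module_name else buffer_name)
    return names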