[checkpointio] fix async io (#6155)

This commit is contained in:
flybird11111 2024-12-16 10:36:28 +08:00 committed by GitHub
parent de3d371f65
commit e994c64568
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 2 additions and 3 deletions

View File

@ -8,8 +8,6 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils.safetensors import move_and_save
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
@ -54,6 +52,7 @@ class GeneralCheckpointIO(CheckpointIO):
pass
if use_async:
from colossalai.utils.safetensors import move_and_save
if id(model) not in self.pinned_state_dicts:
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)

View File

@ -19,7 +19,6 @@ from colossalai.tensor.d_tensor import (
to_global,
to_global_for_customized_distributed_tensor,
)
from colossalai.utils.safetensors import move_and_save
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@ -289,6 +288,7 @@ def async_save_state_dict_shards(
Returns:
int: the total size of shards
"""
from colossalai.utils.safetensors import move_and_save
total_size = 0
shard_filenames = []