mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-05-03 14:08:11 +00:00
[checkpointio] fix async io (#6155)
This commit is contained in:
parent
de3d371f65
commit
e994c64568
colossalai/checkpoint_io
@ -8,8 +8,6 @@ from typing import Optional
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.utils.safetensors import move_and_save
|
|
||||||
|
|
||||||
from .checkpoint_io_base import CheckpointIO
|
from .checkpoint_io_base import CheckpointIO
|
||||||
from .index_file import CheckpointIndexFile
|
from .index_file import CheckpointIndexFile
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@ -54,6 +52,7 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if use_async:
|
if use_async:
|
||||||
|
from colossalai.utils.safetensors import move_and_save
|
||||||
|
|
||||||
if id(model) not in self.pinned_state_dicts:
|
if id(model) not in self.pinned_state_dicts:
|
||||||
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
|
||||||
|
@ -19,7 +19,6 @@ from colossalai.tensor.d_tensor import (
|
|||||||
to_global,
|
to_global,
|
||||||
to_global_for_customized_distributed_tensor,
|
to_global_for_customized_distributed_tensor,
|
||||||
)
|
)
|
||||||
from colossalai.utils.safetensors import move_and_save
|
|
||||||
|
|
||||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||||
WEIGHTS_NAME = "pytorch_model.bin"
|
WEIGHTS_NAME = "pytorch_model.bin"
|
||||||
@ -289,6 +288,7 @@ def async_save_state_dict_shards(
|
|||||||
Returns:
|
Returns:
|
||||||
int: the total size of shards
|
int: the total size of shards
|
||||||
"""
|
"""
|
||||||
|
from colossalai.utils.safetensors import move_and_save
|
||||||
|
|
||||||
total_size = 0
|
total_size = 0
|
||||||
shard_filenames = []
|
shard_filenames = []
|
||||||
|
Loading…
Reference in New Issue
Block a user