Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-22 01:48:07 +00:00
[checkpointio] support non blocking pin load (#6172)
* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
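In outline, the commit stages checkpoint tensors in page-locked (pinned) host memory and copies them to the accelerator with non_blocking=True, paying one synchronization at the end instead of blocking per tensor. A minimal sketch of that pattern, with made-up tensor names and a CUDA device assumed:

import torch

# Stage the checkpoint in pinned host memory; page-locked pages let the
# DMA engine copy to the device asynchronously.
cpu_state = {"weight": torch.randn(1024, 1024)}
pinned = {k: v.pin_memory() for k, v in cpu_state.items()}

# Each copy is enqueued and returns immediately.
gpu_state = {k: v.to("cuda", non_blocking=True) for k, v in pinned.items()}

# One barrier before the tensors are read.
torch.cuda.synchronize()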
@@ -1,18 +1,20 @@
 # coding=utf-8
+import concurrent.futures
 import os
 import re
 from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union

 import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

 from colossalai.accelerator import get_accelerator
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
         if key != "step":
             if param.is_floating_point():
                 value = value.to(param.dtype)
-            value = value.to(param.device)
+            value = value.to(param.device, non_blocking=True)
         return value
     elif isinstance(value, dict):
         return {k: cast(param, v, key=k) for k, v in value.items()}
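A note on the one-word change above: non_blocking=True only makes a host-to-device copy asynchronous when the source tensor sits in pinned memory; from ordinary pageable memory the copy is effectively synchronous, so this line pays off together with the pinned staging buffers added below. A quick illustration (the variable names are mine, not the commit's):

import torch

pageable = torch.randn(4)
pinned = pageable.pin_memory()   # copies into page-locked memory

assert not pageable.is_pinned()  # non_blocking copy falls back to sync
assert pinned.is_pinned()        # non_blocking copy can overlap with compute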
@@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
         elif not strict:
             new_states[k] = v

+    get_accelerator().synchronize()
     optimizer.state.update(new_states)
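The synchronize() here is the other half of the contract: every value.to(param.device, non_blocking=True) issued while building new_states is only enqueued, so the accelerator-wide barrier guarantees the copies have landed before optimizer.state.update(new_states) makes them visible. The same ordering requirement in isolation (a sketch, assuming CUDA):

import torch

src = torch.randn(1 << 20).pin_memory()
dst = src.to("cuda", non_blocking=True)  # enqueued, not yet guaranteed done
torch.cuda.synchronize()                 # barrier: the copy has completed
assert torch.equal(dst.cpu(), src)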
@@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
     return shard_file


-def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
-    pin_mem = dict()
-    for name, tensor in state_dict.items():
-        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
-    return pin_mem
+def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
+    if empty:
+        return torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return tensor.pin_memory()
+
+
+def create_pinned_state_dict(
+    state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
+    empty: bool = True,
+    num_threads: int = 1,
+) -> Dict[str, torch.Tensor]:
+    if num_threads == 1:
+        return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+            elems, spec = tree_flatten(state_dict)
+            future_to_idx = {}
+            for i, elem in enumerate(elems):
+                if isinstance(elem, torch.Tensor):
+                    future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
+            for future in concurrent.futures.as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                elems[idx] = future.result()
+            return tree_unflatten(elems, spec)
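Usage-wise, the rewritten helper keeps the old behavior as the num_threads=1 default (a tree_map over the dict) and otherwise fans the pinning calls out over a thread pool; page-locking large tensors is CPU-side work, so big checkpoints benefit from the parallelism. A hedged sketch, assuming the helper is importable from ColossalAI's checkpoint_io utilities (the module path is my guess):

import torch

from colossalai.checkpoint_io.utils import create_pinned_state_dict  # assumed path

# Both input shapes from the Union signature: a flat state dict and a
# nested per-parameter optimizer state dict.
model_sd = {"linear.weight": torch.randn(1024, 1024)}
optim_sd = {0: {"exp_avg": torch.randn(1024, 1024), "step": torch.tensor(1.0)}}

# empty=True allocates pinned buffers of matching shape (contents filled in
# later by the loader); empty=False pins a copy of the data right away.
pinned_model = create_pinned_state_dict(model_sd, empty=False, num_threads=4)
pinned_optim = create_pinned_state_dict(optim_sd, empty=True, num_threads=4)

assert all(v.is_pinned() for v in pinned_model.values())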