[checkpointio] support non blocking pin load (#6172)

* [checkpointio] support non blocking pin load

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Hongxin Liu committed 2024-12-25 17:03:25 +08:00 (committed by GitHub)
commit af06d162cf (parent 836992438f)
15 changed files with 484 additions and 174 deletions


@@ -1,18 +1,20 @@
 # coding=utf-8
+import concurrent.futures
 import os
 import re
 from collections import abc as container_abcs
 from collections import defaultdict
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple
+from typing import Dict, Iterator, List, Mapping, Optional, OrderedDict, Tuple, Union
 
 import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
-from torch.utils._pytree import tree_map
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 
+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -791,7 +793,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
             if key != "step":
                 if param.is_floating_point():
                     value = value.to(param.dtype)
-                value = value.to(param.device)
+                value = value.to(param.device, non_blocking=True)
             return value
         elif isinstance(value, dict):
             return {k: cast(param, v, key=k) for k, v in value.items()}
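
Note on this hunk: non_blocking=True only makes a host-to-device copy truly asynchronous when the source tensor sits in pinned (page-locked) host memory; with ordinary pageable memory, PyTorch silently falls back to a synchronous copy. A minimal sketch of the difference, assuming a CUDA device is available (tensor names are illustrative, not from this commit):

    import torch

    pageable = torch.randn(1024, 1024)          # ordinary (pageable) host memory
    pinned = pageable.pin_memory()              # page-locked copy of the same data

    a = pageable.to("cuda", non_blocking=True)  # effectively synchronous: source is pageable
    b = pinned.to("cuda", non_blocking=True)    # truly asynchronous: returns before the DMA finishes
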
@@ -811,6 +813,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
         elif not strict:
             new_states[k] = v
+    get_accelerator().synchronize()
     optimizer.state.update(new_states)
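
The added get_accelerator().synchronize() is the matching barrier for those non-blocking copies: .to(..., non_blocking=True) only enqueues the transfer on a stream, so the host must wait before the CPU-side source buffers can safely be reused or freed. A hedged illustration of the hazard this guards against, assuming CUDA (names are illustrative):

    import torch

    src = torch.randn(1 << 20).pin_memory()     # pinned staging buffer on the host
    dst = torch.empty_like(src, device="cuda")
    dst.copy_(src, non_blocking=True)           # returns immediately; transfer is in flight
    # Overwriting or freeing src before synchronizing could corrupt dst.
    torch.cuda.synchronize()                    # now src is safe to reuse
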
@@ -945,8 +948,27 @@ def get_shard_filename(weights_name: str, idx: int):
     return shard_file
 
 
-def create_pinned_state_dict(state_dict: Dict[str, torch.Tensor]):
-    pin_mem = dict()
-    for name, tensor in state_dict.items():
-        pin_mem[name] = torch.empty_like(tensor, pin_memory=True, device="cpu")
-    return pin_mem
+def _pin_tensor(tensor: torch.Tensor, empty: bool = True) -> torch.Tensor:
+    if empty:
+        return torch.empty_like(tensor, pin_memory=True, device="cpu")
+    return tensor.pin_memory()
+
+
+def create_pinned_state_dict(
+    state_dict: Union[Dict[str, torch.Tensor], Dict[int, Dict[str, torch.Tensor]]],
+    empty: bool = True,
+    num_threads: int = 1,
+) -> Dict[str, torch.Tensor]:
+    if num_threads == 1:
+        return tree_map(lambda x: _pin_tensor(x, empty=empty) if isinstance(x, torch.Tensor) else x, state_dict)
+    else:
+        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
+            elems, spec = tree_flatten(state_dict)
+            future_to_idx = {}
+            for i, elem in enumerate(elems):
+                if isinstance(elem, torch.Tensor):
+                    future_to_idx[executor.submit(_pin_tensor, elem, empty)] = i
+            for future in concurrent.futures.as_completed(future_to_idx):
+                idx = future_to_idx[future]
+                elems[idx] = future.result()
+            return tree_unflatten(elems, spec)
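
The empty flag chooses between allocating uninitialized pinned buffers with matching shapes and dtypes (empty=True, a reusable staging area to be filled later) and pinning a copy of the actual data (empty=False). A hypothetical usage sketch, assuming the helper is importable from colossalai.checkpoint_io.utils and a CUDA device is available (the model and checkpoint here are stand-ins, not from this commit):

    import torch
    from colossalai.checkpoint_io.utils import create_pinned_state_dict

    model = torch.nn.Linear(1024, 1024, device="cuda")            # stand-in model
    cpu_sd = {k: v.cpu() for k, v in model.state_dict().items()}  # stand-in checkpoint

    # Pin the checkpoint with 4 worker threads, then stream it to the GPU.
    pinned_sd = create_pinned_state_dict(cpu_sd, empty=False, num_threads=4)
    for name, param in model.state_dict().items():
        param.copy_(pinned_sd[name], non_blocking=True)
    torch.cuda.synchronize()                                      # wait for the async copies
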