[checkpointio] hotfix torch 2.0 compatibility (#4824)

This commit is contained in:
Hongxin Liu
2023-10-07 10:45:52 +08:00
committed by GitHub
parent ad23460cf8
commit cb3a25a062
3 changed files with 26 additions and 12 deletions

View File

@@ -9,6 +9,7 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 import torch
 import torch.nn as nn
+from packaging.version import Version
 from torch.optim import Optimizer
 from colossalai.tensor.d_tensor import (
@@ -663,7 +664,10 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
     """
     # Do the cleaning up as in src code of Pytorch.
-    optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+    if Version(torch.__version__) >= Version("2.0.0"):
+        optimizer._patch_step_function()  # To support multiprocessing pickle/unpickle
+    else:
+        optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault("differentiable", False)