[checkpointio] hotfix torch 2.0 compatibility (#4824)

This commit is contained in:
Hongxin Liu
2023-10-07 10:45:52 +08:00
committed by GitHub
parent ad23460cf8
commit cb3a25a062
3 changed files with 26 additions and 12 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.nn import Parameter
from torch.optim import Optimizer
@@ -676,7 +677,10 @@ class GeminiOptimizer(OptimizerWrapper):
def optimizer_loading_epilogue(self):
    """Epilogue run after loading a state dict into the wrapped pytorch optimizer.

    Re-attaches the optimizer's internal step hook (lost across
    pickle/unpickle, which multiprocessing relies on) and makes sure the
    ``differentiable`` default exists, since older checkpoints may predate it.
    """
    # torch >= 2.0 renamed the private hook `_hook_for_profile` to
    # `_patch_step_function`; dispatch on the installed torch version so the
    # same code works on both sides of the rename.
    if Version(torch.__version__) >= Version("2.0.0"):
        self.optim._patch_step_function()  # To support multiprocessing pickle/unpickle
    else:
        self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
    # Keep any existing value; only fill in the default when it is missing.
    self.optim.defaults.setdefault("differentiable", False)
def load_state_dict(self, state_dict: dict):