Mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-06 03:20:52 +00:00
[checkpointio] hotfix torch 2.0 compatibility (#4824)
@@ -6,6 +6,7 @@ from typing import Any, Dict, Iterator, OrderedDict, Set, Tuple, Union
 
 import torch
 import torch.distributed as dist
+from packaging.version import Version
 from torch.nn import Parameter
 from torch.optim import Optimizer
 
@@ -676,7 +677,10 @@ class GeminiOptimizer(OptimizerWrapper):
 
     def optimizer_loading_epilogue(self):
         # Epilogue when loading state_dict to pytorch optimizer.
-        self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+        if Version(torch.__version__) >= Version("2.0.0"):
+            self.optim._patch_step_function()  # To support multiprocessing pickle/unpickle
+        else:
+            self.optim._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+        self.optim.defaults.setdefault("differentiable", False)
 
     def load_state_dict(self, state_dict: dict):
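For context: PyTorch 2.0 renamed the private Optimizer._hook_for_profile method to _patch_step_function, and its Optimizer.__setstate__ began expecting a "differentiable" entry in optimizer defaults, so the old epilogue raised an AttributeError under torch >= 2.0. Below is a minimal standalone sketch of the same compatibility pattern applied to a plain torch.optim.Adam optimizer rather than GeminiOptimizer; the helper name restore_optimizer is hypothetical, while the version gate, the two private hook names, and the "differentiable" default come from the patch itself.

# Standalone sketch of the version-gated epilogue, assuming a plain
# torch.optim optimizer. `restore_optimizer` is a hypothetical helper.
import torch
from packaging.version import Version


def restore_optimizer(optim: torch.optim.Optimizer, state_dict: dict) -> None:
    optim.load_state_dict(state_dict)
    # torch 2.0 renamed Optimizer._hook_for_profile to _patch_step_function,
    # so re-installing the step hook after loading must branch on the version.
    if Version(torch.__version__) >= Version("2.0.0"):
        optim._patch_step_function()
    else:
        optim._hook_for_profile()
    # torch 2.0 also expects a "differentiable" key in optimizer defaults;
    # state restored from older checkpoints may lack it.
    optim.defaults.setdefault("differentiable", False)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    restore_optimizer(optim, optim.state_dict())

Using packaging.version.Version for the comparison (rather than string comparison on torch.__version__) is what the added import in the first hunk provides; it handles suffixes like "2.0.1+cu118" correctly.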