mirror of https://github.com/hpcaitech/ColossalAI.git
[checkpointio] hotfix torch 2.0 compatibility (#4824)
@@ -9,6 +9,7 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
 import torch
 import torch.nn as nn
+from packaging.version import Version
 from torch.optim import Optimizer

 from colossalai.tensor.d_tensor import (
@@ -663,7 +664,10 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
     """

     # Do the cleaning up as in src code of Pytorch.
-    optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
+    if Version(torch.__version__) >= Version("2.0.0"):
+        optimizer._patch_step_function()  # To support multiprocessing pickle/unpickle
+    else:
+        optimizer._hook_for_profile()  # To support multiprocessing pickle/unpickle.
     optimizer.defaults.setdefault("differentiable", False)
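For context, a minimal standalone sketch of the version gate this hotfix introduces. The `model` and `opt` names are illustrative only and not part of the patch; any torch.optim.Optimizer instance would do.

import torch
from packaging.version import Version

# Illustrative module and optimizer for the example.
model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Per the diff above: on torch >= 2.0 the private hook is named
# Optimizer._patch_step_function(); older versions expose
# Optimizer._hook_for_profile(). Both exist to keep the optimizer
# picklable across processes after a sharded checkpoint load.
if Version(torch.__version__) >= Version("2.0.0"):
    opt._patch_step_function()
else:
    opt._hook_for_profile()

# Restore the "differentiable" default expected by newer torch versions.
opt.defaults.setdefault("differentiable", False)

Comparing with packaging.version.Version rather than raw strings avoids lexicographic pitfalls (e.g. "10.0" < "2.0" under string comparison) and handles local version suffixes such as "2.0.1+cu118".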