mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 20:40:34 +00:00
[misc] fit torch api upgradation and remove legecy import (#6093)
* [amp] fit torch's new api * [amp] fix api call * [amp] fix api call * [misc] fit torch pytree api upgrade * [misc] remove legacy import * [misc] fit torch amp api * [misc] fit torch amp api
This commit is contained in:
@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
from packaging.version import Version
|
||||
from torch.nn import Module
|
||||
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
|
||||
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
|
||||
|
||||
|
||||
# this register are for torch under version 1.13.1, maybe removed in the future
|
||||
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
|
||||
return OrderedDict((key, value) for key, value in zip(context, values))
|
||||
|
||||
|
||||
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
|
||||
if Version(torch.__version__) <= Version("1.13.1"):
|
||||
try:
|
||||
from torch.utils._pytree import register_pytree_node as _register_pytree_node
|
||||
except ImportError:
|
||||
from torch.utils._pytree import _register_pytree_node
|
||||
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
|
||||
|
||||
|
||||
def tree_map_hf(fn: Any, pytree: Any):
|
||||
|
Reference in New Issue
Block a user