[misc] fit torch api upgradation and remove legecy import (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
This commit is contained in:
Hongxin Liu
2024-10-18 16:48:52 +08:00
committed by GitHub
parent 5ddad486ca
commit 58d8b8a2dd
7 changed files with 20 additions and 12 deletions

View File

@@ -3,8 +3,9 @@ from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from packaging.version import Version
from torch.nn import Module
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, _register_pytree_node, tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import SUPPORTED_NODES, TreeSpec, tree_flatten, tree_map, tree_unflatten
# this register are for torch under version 1.13.1, maybe removed in the future
@@ -16,7 +17,12 @@ def _odict_unflatten(values: List[Any], context: Any) -> "OrderedDict[Any, Any]"
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
if Version(torch.__version__) <= Version("1.13.1"):
try:
from torch.utils._pytree import register_pytree_node as _register_pytree_node
except ImportError:
from torch.utils._pytree import _register_pytree_node
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):