[misc] fit torch api upgradation and remove legecy import (#6093)

* [amp] fit torch's new api

* [amp] fix api call

* [amp] fix api call

* [misc] fit torch pytree api upgrade

* [misc] remove legacy import

* [misc] fit torch amp api

* [misc] fit torch amp api
This commit is contained in:
Hongxin Liu
2024-10-18 16:48:52 +08:00
committed by GitHub
parent 5ddad486ca
commit 58d8b8a2dd
7 changed files with 20 additions and 12 deletions

View File

@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
"""
Return autocast function
"""
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)