mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00
[misc] fit torch api upgradation and remove legecy import (#6093)
* [amp] fit torch's new api * [amp] fix api call * [amp] fix api call * [misc] fit torch pytree api upgrade * [misc] remove legacy import * [misc] fit torch amp api * [misc] fit torch amp api
This commit is contained in:
@@ -279,4 +279,4 @@ class CudaAccelerator(BaseAccelerator):
|
||||
"""
|
||||
Return autocast function
|
||||
"""
|
||||
return torch.cuda.amp.autocast(enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
||||
return torch.amp.autocast(device_type="cuda", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
|
||||
|
Reference in New Issue
Block a user