mirror of https://github.com/hpcaitech/ColossalAI.git
[npu] add npu support for hybrid plugin and llama (#5090)
* llama 3d
* update
* fix autocast
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable
 
 import torch
 import torch.distributed as dist
@@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
     return _dispatch_device_func("reset_max_memory_allocated", device)
 
 
+def reset_max_memory_cached(device=None) -> None:
+    return _dispatch_device_func("reset_max_memory_cached", device)
+
+
 def memory_reserved(device=None) -> int:
     return _dispatch_device_func("memory_reserved", device)
 
@@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
 
 def reset_peak_memory_stats(device=None) -> None:
     return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+
+
+def autocast() -> Callable:
+    if torch.cuda.is_available():
+        return torch.cuda.amp.autocast()
+    elif IS_NPU_AVAILABLE:
+        return torch.npu.amp.autocast()
+    else:
+        raise RuntimeError("No device available")
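For context on the pattern above: every memory helper in this file defers to _dispatch_device_func, whose definition lies outside the diff. A minimal sketch of how such a dispatcher could work, assuming it simply resolves the function of the same name on the active backend's module (torch.cuda or torch.npu); the IS_NPU_AVAILABLE flag here is a stand-in for the real module-level constant, not its actual definition:

import torch

# Assumption: in the real module IS_NPU_AVAILABLE is set at import time;
# here it is approximated from the presence of the torch.npu namespace
# (provided by the torch_npu package on Ascend hardware).
IS_NPU_AVAILABLE = hasattr(torch, "npu") and torch.npu.is_available()


def _dispatch_device_func(fn_name: str, *args, **kwargs):
    # Sketch: look up the same-named function on the active backend's
    # module, so e.g. memory_reserved() maps to torch.cuda.memory_reserved()
    # on GPU or torch.npu.memory_reserved() on NPU.
    if torch.cuda.is_available():
        return getattr(torch.cuda, fn_name)(*args, **kwargs)
    elif IS_NPU_AVAILABLE:
        return getattr(torch.npu, fn_name)(*args, **kwargs)
    raise RuntimeError("No device available")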
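Downstream training code can then request mixed precision without branching on the backend. A usage sketch, assuming the wrapper is exported as colossalai.utils.device.autocast (the import path is an assumption, and the model and data are placeholders); note that per the diff, autocast() raises RuntimeError when neither CUDA nor an NPU is available:

import torch
from colossalai.utils.device import autocast  # assumed import path

# Assumes a CUDA device is present (or, with torch_npu installed, an NPU,
# in which case the tensors would target "npu" instead of "cuda").
model = torch.nn.Linear(16, 4).to("cuda")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
inputs = torch.randn(8, 16, device="cuda")

# autocast() returns the backend-appropriate AMP context manager
# (torch.cuda.amp.autocast on GPU, torch.npu.amp.autocast on Ascend NPU),
# so the forward pass runs in mixed precision on either backend.
with autocast():
    loss = model(inputs).sum()
loss.backward()
optimizer.step()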