[npu] add npu support for hybrid plugin and llama (#5090)

* llama 3d

* update

* fix autocast
Author: Xuanlei Zhao
Date: 2023-11-22 19:23:21 +08:00
Committed by: GitHub
Parent: aae496631c
Commit: 3acbf6d496
9 changed files with 61 additions and 40 deletions
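Every helper in the diff below routes through _dispatch_device_func, whose definition lies outside this excerpt. A minimal sketch of such a dispatcher, assuming it resolves the same-named function on torch.cuda or torch.npu (the body here is an illustration, not the commit's code):

    import torch

    try:
        # torch_npu is Ascend's PyTorch plugin; importing it registers torch.npu.
        import torch_npu  # noqa: F401
        IS_NPU_AVAILABLE = torch.npu.is_available()
    except ImportError:
        IS_NPU_AVAILABLE = False

    def _dispatch_device_func(fn_name: str, *args, **kwargs):
        # Call the namesake function on whichever backend is present.
        if torch.cuda.is_available():
            return getattr(torch.cuda, fn_name)(*args, **kwargs)
        elif IS_NPU_AVAILABLE:
            return getattr(torch.npu, fn_name)(*args, **kwargs)
        else:
            raise RuntimeError("No device available")

Under that assumption, memory_reserved(device) becomes torch.cuda.memory_reserved(device) on a CUDA box and torch.npu.memory_reserved(device) on an Ascend box.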


@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Callable
 
 import torch
 import torch.distributed as dist
@@ -191,6 +191,10 @@ def reset_max_memory_allocated(device=None) -> None:
     return _dispatch_device_func("reset_max_memory_allocated", device)
 
 
+def reset_max_memory_cached(device=None) -> None:
+    return _dispatch_device_func("reset_max_memory_cached", device)
+
+
 def memory_reserved(device=None) -> int:
     return _dispatch_device_func("memory_reserved", device)
@@ -205,3 +209,15 @@ def set_per_process_memory_fraction(fraction: float, device=None) -> None:
 
 def reset_peak_memory_stats(device=None) -> None:
     return _dispatch_device_func("reset_peak_memory_stats", device)
+
+
+# amp
+def autocast() -> Callable:
+    if torch.cuda.is_available():
+        return torch.cuda.amp.autocast()
+    elif IS_NPU_AVAILABLE:
+        return torch.npu.amp.autocast()
+    else:
+        raise RuntimeError("No device available")
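
A hedged usage sketch of the new device-agnostic autocast; the model, data, and device-selection lines are illustrative, and the import path is assumed from the diff's context:

    import torch
    from colossalai.utils.device import autocast  # module path assumed

    device = "cuda" if torch.cuda.is_available() else "npu"
    model = torch.nn.Linear(16, 16).to(device)
    x = torch.randn(4, 16, device=device)

    # autocast() returns the backend's context manager instance,
    # i.e. torch.cuda.amp.autocast() or torch.npu.amp.autocast().
    with autocast():
        y = model(x)

Returning the backend's context manager keeps call sites free of per-device branches, mirroring the _dispatch_device_func pattern used by the memory helpers above.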