[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
Commit: d202cc28c0 (parent dd2c28a323)
Author: Hongxin Liu
Date: 2024-01-09 10:20:05 +08:00 (committed by GitHub)
128 changed files with 1773 additions and 868 deletions


@@ -3,10 +3,10 @@ import torch
 import torch.distributed as dist
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 BATCH_SIZE = 4
 NUM_EXPERTS = 4
@@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     torch.manual_seed(rs + local_rank)  # set each process has different random seed
     # get randomized data
-    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(
+        BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
+    )
     layer = SparseMLP(
         hidden_size=hidden_size,
@@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
         router_top_k=topk,
         router_capacity_factor_train=1.0,
     )
-    layer = layer.to(get_current_device())
+    layer = layer.to(get_accelerator().get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
@@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.enable_kernel = False
     old_out = layer(tokens)
     ech = old_out.shape
-    grad = torch.randn(ech, device=get_current_device())
+    grad = torch.randn(ech, device=get_accelerator().get_current_device())
     old_out.backward(grad)  # get gradient
     # save all results
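
For reference, a minimal sketch of the migration pattern this diff applies: replace the old `colossalai.utils.get_current_device` helper with the accelerator API confirmed in the hunks above. The tensor shapes and the `torch.nn.Linear` module below are illustrative only, not part of the changed test.

# Minimal sketch of the accelerator-API migration (illustrative shapes/modules).
import torch

from colossalai.accelerator import get_accelerator  # new API introduced by this commit

# Old pattern (removed in this commit):
#   from colossalai.utils import get_current_device
#   device = get_current_device()

# New pattern: query the accelerator object, which is expected to resolve to
# the active backend (e.g. CUDA, NPU, or CPU) and report its current device.
device = get_accelerator().get_current_device()

# Tensors and modules are then placed on that device exactly as before.
tokens = torch.randn(4, 128, device=device, requires_grad=True)
model = torch.nn.Linear(128, 128).to(device)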