Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 02:51:59 +00:00
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* udpate
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
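The change itself is mechanical: each use of colossalai.utils.get_current_device() in the test is replaced with the accelerator API, i.e. get_accelerator().get_current_device(). Below is a minimal sketch of the new pattern, assuming an initialized backend (CUDA or NPU); only the two accelerator calls come from this diff, the surrounding tensor setup is illustrative.

# Sketch of the migration this commit applies (illustrative, not part of the diff).
import torch

from colossalai.accelerator import get_accelerator

# Old pattern, removed by this commit:
#   from colossalai.utils import get_current_device
#   tokens = torch.randn(4, 128, device=get_current_device())

# New pattern: resolve the current device through the accelerator API,
# which dispatches to whichever backend (e.g. CUDA or NPU) is active.
device = get_accelerator().get_current_device()
tokens = torch.randn(4, 128, device=device, requires_grad=True)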
@@ -3,10 +3,10 @@ import torch
 import torch.distributed as dist
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.moe import SparseMLP
 from colossalai.moe.manager import MOE_MANAGER
 from colossalai.testing import rerun_if_address_is_in_use, spawn
-from colossalai.utils import get_current_device
 
 BATCH_SIZE = 4
 NUM_EXPERTS = 4
@@ -28,7 +28,9 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     torch.manual_seed(rs + local_rank)  # set each process has different random seed
 
     # get randomized data
-    tokens = torch.randn(BATCH_SIZE, hidden_size, dtype=data_type, device=get_current_device(), requires_grad=True)
+    tokens = torch.randn(
+        BATCH_SIZE, hidden_size, dtype=data_type, device=get_accelerator().get_current_device(), requires_grad=True
+    )
 
     layer = SparseMLP(
         hidden_size=hidden_size,
@@ -37,7 +39,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
         router_top_k=topk,
         router_capacity_factor_train=1.0,
     )
-    layer = layer.to(get_current_device())
+    layer = layer.to(get_accelerator().get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
 
@@ -45,7 +47,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     layer.enable_kernel = False
     old_out = layer(tokens)
    ech = old_out.shape
-    grad = torch.randn(ech, device=get_current_device())
+    grad = torch.randn(ech, device=get_accelerator().get_current_device())
     old_out.backward(grad)  # get gradient
 
     # save all results