[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
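The diff below migrates VocabParallelEmbedding, VocabParallelEmbedding1D, HiddenParallelEmbedding, and HiddenParallelEmbedding1D from the CUDA-specific get_current_device helper to the backend-agnostic accelerator API. A minimal sketch of the call-site pattern, using only the two entry points that appear in the diff (colossalai.utils.get_current_device before, colossalai.accelerator.get_accelerator after):

```python
import torch

from colossalai.accelerator import get_accelerator

# Old pattern (removed in this commit):
#   from colossalai.utils import get_current_device
#   device = get_current_device()

# New pattern: ask the active accelerator backend (CUDA, NPU, ...) for the
# current device and use it anywhere a torch device is needed.
device = get_accelerator().get_current_device()

position_ids = torch.arange(0, 16, dtype=torch.long, device=device)
weight = torch.empty(8, 4, dtype=torch.float32, device=device)
```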
@@ -5,6 +5,7 @@ from torch import nn as nn
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode, seed
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.base_layer import ParallelLayer
@@ -12,7 +13,6 @@ from colossalai.legacy.nn.layer.parallel_1d._utils import gather_forward_split_b
 from colossalai.legacy.nn.layer.parallel_1d.layers import Linear1D_Row
 from colossalai.legacy.nn.layer.utils import divide
 from colossalai.legacy.registry import LAYERS, LOSSES
-from colossalai.utils import get_current_device
 
 
 class VocabParallelEmbedding(torch.nn.Module):
@@ -96,7 +96,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
 
@@ -194,7 +196,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
         self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index
 
         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(self.num_embeddings_per_partition, self.embedding_dim, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)
 
@@ -439,7 +441,9 @@ class HiddenParallelEmbedding(torch.nn.Module):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])
         if position_ids is None:
-            position_ids = torch.arange(0, input_shape[-1] + 0, dtype=torch.long, device=get_current_device())
+            position_ids = torch.arange(
+                0, input_shape[-1] + 0, dtype=torch.long, device=get_accelerator().get_current_device()
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
         position_embeddings = self.position_embeddings(position_ids)
 
@@ -532,7 +536,7 @@ class HiddenParallelEmbedding1D(torch.nn.Module):
         self._weight = None
 
         # Allocate weights and initialize.
-        factory_kwargs = {"device": get_current_device(), "dtype": dtype}
+        factory_kwargs = {"device": get_accelerator().get_current_device(), "dtype": dtype}
         self.weight = Parameter(torch.empty(num_embeddings, embed_dim_per_partition, **factory_kwargs))
         init.uniform_(self.weight, -1, 1)