Mirror of https://github.com/hpcaitech/ColossalAI.git
[npu] change device to accelerator api (#5239)
* update accelerator
* fix timer
* fix amp
* update
* fix
* update bug
* add error raise
* fix autocast
* fix set device
* remove doc accelerator
* update doc
* update doc
* update doc
* use nullcontext
* update cpu
* update null context
* change time limit for example
* update
* update
* update
* update
* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
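
The change is mechanical: every call site that previously fetched the device with get_current_device() from colossalai.utils now resolves it through the accelerator abstraction, so the same legacy test code runs on whichever backend (CUDA, NPU, CPU) the process was initialized with. Below is a minimal sketch of the new call pattern, assuming only what the diff itself shows; the make_master_tensor helper is hypothetical and not part of this PR.

import torch

from colossalai.accelerator import get_accelerator


def make_master_tensor(shape):
    # Hypothetical helper for illustration only.
    # Old pattern: device = get_current_device()   (imported from colossalai.utils)
    # New pattern: ask the accelerator abstraction for the current device,
    # which hides whether the backend is CUDA, NPU, or CPU.
    device = get_accelerator().get_current_device()
    return torch.randn(shape, dtype=torch.float32, device=device)
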
@@ -2,6 +2,7 @@ import torch
 import torch.distributed as dist
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.global_variables import tensor_parallel_env as env
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding1D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
 
 def check_linear_col():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -68,7 +68,7 @@ def check_linear_col():
     print_rank_0("linear_col forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     dist.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
     grad = grad.clone()
@@ -91,7 +91,7 @@ def check_linear_col():
 
 
 def check_linear_row():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -137,7 +137,7 @@ def check_linear_row():
     print_rank_0("linear_row forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     dist.broadcast(grad_master, src=0)
     grad = grad_master.clone()
     out.backward(grad)
@@ -159,7 +159,7 @@ def check_linear_row():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -201,7 +201,7 @@ def check_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -243,7 +243,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -309,7 +309,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -420,7 +420,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -508,7 +508,7 @@ def check_vocab_parallel_loss():
 
 @torch.no_grad()
 def check_linear_row_stream_inference():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE

@@ -1,5 +1,6 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import (
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding2D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
 
 def check_linear():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = HIDDEN_SIZE
@@ -74,7 +74,7 @@ def check_linear():
     print_rank_0("linear forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_linear():
 
 
 def check_layernorm():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     EPS = 1e-12
@@ -139,7 +139,7 @@ def check_layernorm():
     print_rank_0("layer norm forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -154,7 +154,7 @@ def check_layernorm():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -201,7 +201,7 @@ def check_embed():
 
 
 def check_patch_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -274,7 +274,7 @@ def check_patch_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -321,7 +321,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = NUM_CLASSES
@@ -371,7 +371,7 @@ def check_classifier_no_given_weight():
     print_rank_0("classifier (no given weight) forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     # grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -399,7 +399,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -519,7 +519,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -608,7 +608,7 @@ def check_loss():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -645,7 +645,7 @@ def check_vocab_parallel_loss():
 
 
 # def check_attention():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
@@ -683,7 +683,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('self attention backward: pass')
 
 # def check_mlp():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 
@@ -716,7 +716,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('mlp backward: pass')
 
 # def check_transformerlayer():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2

@@ -3,11 +3,11 @@
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
 
@@ -27,7 +27,7 @@ def check_AB():
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
+    A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, DEPTH, dim=0)[i]
     A = torch.chunk(A, DEPTH, dim=-1)[j]
@@ -35,7 +35,7 @@ def check_AB():
     A.requires_grad = True
 
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
-    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
+    B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(B_master, src=0)
     B = torch.chunk(B_master, DEPTH, dim=0)[i]
     B = torch.chunk(B, DEPTH, dim=-1)[j]
@@ -72,7 +72,7 @@ def check_AB():
     print_rank_0("AB forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -105,7 +105,7 @@ def check_ABT():
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
     dtype = torch.float
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -184,7 +184,7 @@ def check_ATB():
     )
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float
 
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)

@@ -1,6 +1,7 @@
 import torch
 from torch.nn import Parameter
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context.parallel_mode import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import (
@@ -17,13 +18,12 @@ from colossalai.legacy.nn import (
     VocabParallelEmbedding2p5D,
 )
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import *
 
 
 def check_linear():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -76,7 +76,7 @@ def check_linear():
     print_rank_0("linear forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -104,7 +104,7 @@ def check_linear():
 
 
 def check_layernorm():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     EPS = 1e-12
@@ -141,7 +141,7 @@ def check_layernorm():
     print_rank_0("layer norm forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -156,7 +156,7 @@ def check_layernorm():
 
 
 def check_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -204,7 +204,7 @@ def check_embed():
 
 
 def check_patch_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -278,7 +278,7 @@ def check_patch_embed():
 
 
 def check_vocab_parallel_embed():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -326,7 +326,7 @@ def check_vocab_parallel_embed():
 
 
 def check_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = NUM_CLASSES
@@ -377,7 +377,7 @@ def check_classifier_no_given_weight():
     print_rank_0("classifier (no given weight) forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     # grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -405,7 +405,7 @@ def check_classifier_no_given_weight():
 
 
 def check_vocab_parallel_classifier_no_given_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 
 
 def check_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -524,7 +524,7 @@ def check_classifier_given_embed_weight():
 
 
 def check_vocab_parallel_classifier_given_embed_weight():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 
 def check_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -613,7 +613,7 @@ def check_loss():
 
 
 def check_vocab_parallel_loss():
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -650,7 +650,7 @@ def check_vocab_parallel_loss():
 
 
 # def check_attention():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2
@@ -689,7 +689,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('self attention backward: pass')
 
 # def check_mlp():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 
@@ -725,7 +725,7 @@ def check_vocab_parallel_loss():
 #     print_rank_0('mlp backward: pass')
 
 # def check_transformerlayer():
-#     device = get_current_device()
+#     device = get_accelerator().get_current_device()
 #     dtype = torch.float32
 #     INPUT_SIZE = HIDDEN_SIZE
 #     NUM_ATTENTION_HEADS = 2

@@ -1,10 +1,10 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
 from colossalai.legacy.utils import print_rank_0
-from colossalai.utils import get_current_device
 
 from .common import *
 
@@ -25,7 +25,7 @@ def check_AB():
     k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
-    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
+    A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(A_master, src=0)
     A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
     A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
@@ -33,7 +33,7 @@ def check_AB():
     A.requires_grad = True
 
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
-    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
+    B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(B_master, src=0)
     B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
     B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
@@ -70,7 +70,7 @@ def check_AB():
     print_rank_0("AB forward: pass")
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
     grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_ABT():
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
     dtype = torch.float
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
     j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -184,7 +184,7 @@ def check_ATB():
     )
     tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
 
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)

@@ -5,6 +5,7 @@ import time
 
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
 from colossalai.legacy.core import global_context
 from colossalai.legacy.nn import (
@@ -23,7 +24,6 @@ from colossalai.legacy.nn import (
 from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
 from colossalai.legacy.utils import print_rank_0
 from colossalai.logging import get_dist_logger
-from colossalai.utils import get_current_device
 
 from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 
@@ -31,7 +31,7 @@ from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
 def check_linear():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
     OUTPUT_SIZE = 2 * HIDDEN_SIZE
 
@@ -84,7 +84,7 @@ def check_linear():
     logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -119,7 +119,7 @@ def check_linear():
 def check_layernorm():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -206,7 +206,7 @@ def check_layernorm():
 def check_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -258,7 +258,7 @@ def check_classifier_no_given_weight():
     logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, device=get_current_device())
+    grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -306,7 +306,7 @@ def check_classifier_no_given_weight():
 def check_vocab_parallel_classifier_no_given_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     INPUT_SIZE = HIDDEN_SIZE
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight():
 def check_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     dtype = torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -463,7 +463,7 @@ def check_classifier_given_embed_weight():
     logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C)))
 
     grad_shape = C_master.shape
-    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
+    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
     torch.distributed.broadcast(grad_master, src=0)
     grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
     grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -497,7 +497,7 @@ def check_classifier_given_embed_weight():
 def check_vocab_parallel_classifier_given_embed_weight():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
 
 def check_patch_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -678,7 +678,7 @@ def check_patch_embed():
 
 def check_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -746,7 +746,7 @@ def check_embed():
 
 def check_vocab_parallel_embed():
     rank = torch.distributed.get_rank()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     logger = get_dist_logger()
     torch.float32
 
@@ -823,7 +823,7 @@ def check_vocab_parallel_embed():
 def check_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
     weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -876,7 +876,7 @@ def check_loss():
 def check_vocab_parallel_loss():
     rank = torch.distributed.get_rank()
     logger = get_dist_logger()
-    device = get_current_device()
+    device = get_accelerator().get_current_device()
     torch.float32
 
     input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

@@ -1,9 +1,9 @@
 import torch
 
+from colossalai.accelerator import get_accelerator
 from colossalai.legacy.context import ParallelMode
 from colossalai.legacy.core import global_context as gpc
 from colossalai.legacy.nn import TransformerSelfAttentionRing
-from colossalai.utils import get_current_device
 
 
 def check_selfattention():
@@ -13,10 +13,10 @@ def check_selfattention():
     HIDDEN_SIZE = 16
 
     layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
-    layer = layer.to(get_current_device())
+    layer = layer.to(get_accelerator().get_current_device())
 
-    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
+    hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device())
     attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
-        get_current_device()
+        get_accelerator().get_current_device()
     )
     layer(hidden_states, attention_mask)