[npu] change device to accelerator api (#5239)

* update accelerator

* fix timer

* fix amp

* update

* fix

* update bug

* add error raise

* fix autocast

* fix set device

* remove doc accelerator

* update doc

* update doc

* update doc

* use nullcontext

* update cpu

* update null context

* change time limit for example

* update

* update

* update

* update

* [npu] polish accelerator code

---------

Co-authored-by: Xuanlei Zhao <xuanlei.zhao@gmail.com>
Co-authored-by: zxl <43881818+oahzxl@users.noreply.github.com>
Author: Hongxin Liu
Date: 2024-01-09 10:20:05 +08:00
Committed by: GitHub
Parent: dd2c28a323
Commit: d202cc28c0
128 changed files with 1773 additions and 868 deletions
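The change is mechanical across the legacy test suite: device lookups move from the `colossalai.utils.get_current_device` helper to the accelerator singleton exposed by `colossalai.accelerator`. A minimal sketch of the pattern the diffs below apply (the tensor shape is illustrative only):

```python
import torch

from colossalai.accelerator import get_accelerator

# Resolve the active device through the accelerator object rather than the
# colossalai.utils.get_current_device helper used previously.
device = get_accelerator().get_current_device()

# Illustrative allocation mirroring the test changes; the (8, 8) shape is arbitrary.
x = torch.randn(8, 8, device=device)
```

Routing device selection through the accelerator object is what lets the same test code target NPU as well as CUDA or CPU, which is the point of this PR's title.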

View File

@@ -2,6 +2,7 @@ import torch
import torch.distributed as dist
from torch.nn import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.global_variables import tensor_parallel_env as env
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
VocabParallelEmbedding1D,
)
from colossalai.legacy.utils import print_rank_0
from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear_col():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -68,7 +68,7 @@ def check_linear_col():
print_rank_0("linear_col forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
dist.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=-1)[i]
grad = grad.clone()
@@ -91,7 +91,7 @@ def check_linear_col():
def check_linear_row():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -137,7 +137,7 @@ def check_linear_row():
print_rank_0("linear_row forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
dist.broadcast(grad_master, src=0)
grad = grad_master.clone()
out.backward(grad)
@@ -159,7 +159,7 @@ def check_linear_row():
def check_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -201,7 +201,7 @@ def check_embed():
def check_vocab_parallel_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -243,7 +243,7 @@ def check_vocab_parallel_embed():
def check_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -309,7 +309,7 @@ def check_classifier_no_given_weight():
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -369,7 +369,7 @@ def check_vocab_parallel_classifier_no_given_weight():
def check_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -420,7 +420,7 @@ def check_classifier_given_embed_weight():
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
def check_vocab_parallel_loss():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -508,7 +508,7 @@ def check_vocab_parallel_loss():
@torch.no_grad()
def check_linear_row_stream_inference():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE

View File

@@ -1,5 +1,6 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import (
@@ -16,13 +17,12 @@ from colossalai.legacy.nn import (
VocabParallelEmbedding2D,
)
from colossalai.legacy.utils import print_rank_0
from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
def check_linear():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = HIDDEN_SIZE
@@ -74,7 +74,7 @@ def check_linear():
print_rank_0("linear forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_linear():
def check_layernorm():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
@@ -139,7 +139,7 @@ def check_layernorm():
print_rank_0("layer norm forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -154,7 +154,7 @@ def check_layernorm():
def check_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -201,7 +201,7 @@ def check_embed():
def check_patch_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -274,7 +274,7 @@ def check_patch_embed():
def check_vocab_parallel_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -321,7 +321,7 @@ def check_vocab_parallel_embed():
def check_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
@@ -371,7 +371,7 @@ def check_classifier_no_given_weight():
print_rank_0("classifier (no given weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
# grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -399,7 +399,7 @@ def check_classifier_no_given_weight():
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -467,7 +467,7 @@ def check_vocab_parallel_classifier_no_given_weight():
def check_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -519,7 +519,7 @@ def check_classifier_given_embed_weight():
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -573,7 +573,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
def check_loss():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -608,7 +608,7 @@ def check_loss():
def check_vocab_parallel_loss():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
@@ -645,7 +645,7 @@ def check_vocab_parallel_loss():
# def check_attention():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
@@ -683,7 +683,7 @@ def check_vocab_parallel_loss():
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
@@ -716,7 +716,7 @@ def check_vocab_parallel_loss():
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2

View File

@@ -3,11 +3,11 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2d._operation import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.legacy.utils import print_rank_0
from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, SEQ_LENGTH, check_equal
@@ -27,7 +27,7 @@ def check_AB():
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, DEPTH, dim=0)[i]
A = torch.chunk(A, DEPTH, dim=-1)[j]
@@ -35,7 +35,7 @@ def check_AB():
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, DEPTH, dim=0)[i]
B = torch.chunk(B, DEPTH, dim=-1)[j]
@@ -72,7 +72,7 @@ def check_AB():
print_rank_0("AB forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -105,7 +105,7 @@ def check_ABT():
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
device = get_accelerator().get_current_device()
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
@@ -184,7 +184,7 @@ def check_ATB():
)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float
j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)

View File

@@ -1,6 +1,7 @@
import torch
from torch.nn import Parameter
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import (
@@ -17,13 +18,12 @@ from colossalai.legacy.nn import (
VocabParallelEmbedding2p5D,
)
from colossalai.legacy.utils import print_rank_0
from colossalai.utils import get_current_device
from .common import *
def check_linear():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -76,7 +76,7 @@ def check_linear():
print_rank_0("linear forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -104,7 +104,7 @@ def check_linear():
def check_layernorm():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
EPS = 1e-12
@@ -141,7 +141,7 @@ def check_layernorm():
print_rank_0("layer norm forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -156,7 +156,7 @@ def check_layernorm():
def check_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -204,7 +204,7 @@ def check_embed():
def check_patch_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -278,7 +278,7 @@ def check_patch_embed():
def check_vocab_parallel_embed():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -326,7 +326,7 @@ def check_vocab_parallel_embed():
def check_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = NUM_CLASSES
@@ -377,7 +377,7 @@ def check_classifier_no_given_weight():
print_rank_0("classifier (no given weight) forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
# grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -405,7 +405,7 @@ def check_classifier_no_given_weight():
def check_vocab_parallel_classifier_no_given_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -472,7 +472,7 @@ def check_vocab_parallel_classifier_no_given_weight():
def check_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -524,7 +524,7 @@ def check_classifier_given_embed_weight():
def check_vocab_parallel_classifier_given_embed_weight():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -578,7 +578,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
def check_loss():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -613,7 +613,7 @@ def check_loss():
def check_vocab_parallel_loss():
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -650,7 +650,7 @@ def check_vocab_parallel_loss():
# def check_attention():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2
@@ -689,7 +689,7 @@ def check_vocab_parallel_loss():
# print_rank_0('self attention backward: pass')
# def check_mlp():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
@@ -725,7 +725,7 @@ def check_vocab_parallel_loss():
# print_rank_0('mlp backward: pass')
# def check_transformerlayer():
# device = get_current_device()
# device = get_accelerator().get_current_device()
# dtype = torch.float32
# INPUT_SIZE = HIDDEN_SIZE
# NUM_ATTENTION_HEADS = 2

View File

@@ -1,10 +1,10 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn.layer.parallel_2p5d._operation import Matmul_AB_2p5D, Matmul_ABT_2p5D, Matmul_ATB_2p5D
from colossalai.legacy.utils import print_rank_0
from colossalai.utils import get_current_device
from .common import *
@@ -25,7 +25,7 @@ def check_AB():
k = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
A_master = torch.randn(A_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(A_master, src=0)
A = torch.chunk(A_master, TESSERACT_DIM, dim=0)[i]
A = torch.chunk(A, TESSERACT_DIM, dim=-1)[j]
@@ -33,7 +33,7 @@ def check_AB():
A.requires_grad = True
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
B_master = torch.randn(B_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(B_master, src=0)
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i]
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j]
@@ -70,7 +70,7 @@ def check_AB():
print_rank_0("AB forward: pass")
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, TESSERACT_DIM, dim=0)[i]
grad = torch.chunk(grad, TESSERACT_DIM, dim=-1)[j]
@@ -103,7 +103,7 @@ def check_ABT():
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
dtype = torch.float
device = get_current_device()
device = get_accelerator().get_current_device()
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
j = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
@@ -184,7 +184,7 @@ def check_ATB():
)
tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float
i = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)

View File

@@ -5,6 +5,7 @@ import time
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.constants import INPUT_GROUP_3D, OUTPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.legacy.core import global_context
from colossalai.legacy.nn import (
@@ -23,7 +24,6 @@ from colossalai.legacy.nn import (
from colossalai.legacy.nn.layer.parallel_3d._utils import get_parallel_mode_from_env
from colossalai.legacy.utils import print_rank_0
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal
@@ -31,7 +31,7 @@ from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, IMG_SIZE, NUM_CLASSES, SEQ_L
def check_linear():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
INPUT_SIZE = HIDDEN_SIZE
OUTPUT_SIZE = 2 * HIDDEN_SIZE
@@ -84,7 +84,7 @@ def check_linear():
logger.info("Rank {} linear forward: {}".format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=-1)[j]
@@ -119,7 +119,7 @@ def check_linear():
def check_layernorm():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -206,7 +206,7 @@ def check_layernorm():
def check_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -258,7 +258,7 @@ def check_classifier_no_given_weight():
logger.info("Rank {} classifier (no given weight) forward: {}".format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, device=get_current_device())
grad_master = torch.randn(grad_shape, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -306,7 +306,7 @@ def check_classifier_no_given_weight():
def check_vocab_parallel_classifier_no_given_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
INPUT_SIZE = HIDDEN_SIZE
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -413,7 +413,7 @@ def check_vocab_parallel_classifier_no_given_weight():
def check_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
dtype = torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
@@ -463,7 +463,7 @@ def check_classifier_given_embed_weight():
logger.info("Rank {} classifier (given embed weight) forward: {}".format(rank, check_equal(out, C)))
grad_shape = C_master.shape
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
grad_master = torch.randn(grad_shape, dtype=dtype, device=get_accelerator().get_current_device())
torch.distributed.broadcast(grad_master, src=0)
grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
grad = torch.chunk(grad, DEPTH, dim=0)[j]
@@ -497,7 +497,7 @@ def check_classifier_given_embed_weight():
def check_vocab_parallel_classifier_given_embed_weight():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -580,7 +580,7 @@ def check_vocab_parallel_classifier_given_embed_weight():
def check_patch_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
device = get_accelerator().get_current_device()
logger = get_dist_logger()
torch.float32
@@ -678,7 +678,7 @@ def check_patch_embed():
def check_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
device = get_accelerator().get_current_device()
logger = get_dist_logger()
torch.float32
@@ -746,7 +746,7 @@ def check_embed():
def check_vocab_parallel_embed():
rank = torch.distributed.get_rank()
device = get_current_device()
device = get_accelerator().get_current_device()
logger = get_dist_logger()
torch.float32
@@ -823,7 +823,7 @@ def check_vocab_parallel_embed():
def check_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
@@ -876,7 +876,7 @@ def check_loss():
def check_vocab_parallel_loss():
rank = torch.distributed.get_rank()
logger = get_dist_logger()
device = get_current_device()
device = get_accelerator().get_current_device()
torch.float32
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)

View File

@@ -1,9 +1,9 @@
import torch
from colossalai.accelerator import get_accelerator
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.nn import TransformerSelfAttentionRing
from colossalai.utils import get_current_device
def check_selfattention():
@@ -13,10 +13,10 @@ def check_selfattention():
HIDDEN_SIZE = 16
layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
layer = layer.to(get_current_device())
layer = layer.to(get_accelerator().get_current_device())
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_accelerator().get_current_device())
attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
get_current_device()
get_accelerator().get_current_device()
)
layer(hidden_states, attention_mask)
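Among the commit-message bullets above, "fix autocast" and "use nullcontext" refer to mixed-precision handling. The following is only a hedged sketch of that general pattern, not the commit's actual implementation; `autocast_or_noop` is a hypothetical helper name used for illustration:

```python
import contextlib

import torch


def autocast_or_noop(enabled: bool, device_type: str = "cuda"):
    """Return an autocast context when mixed precision is enabled,
    otherwise a no-op nullcontext (hypothetical helper, for illustration)."""
    if enabled:
        return torch.autocast(device_type=device_type)
    return contextlib.nullcontext()


# Example: skip autocast on backends where it is unsupported or disabled.
with autocast_or_noop(enabled=torch.cuda.is_available()):
    y = torch.ones(2, 2) * 2
```

`contextlib.nullcontext()` gives a context manager that does nothing, so the same `with` block works whether or not autocast is actually engaged.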