From c2947dadf1ae7d8621e6e2058463642d675e071d Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 10 Nov 2022 17:03:21 +0800 Subject: [PATCH] [inference] streaming Linear 1D Row inference (#1874) --- colossalai/nn/layer/parallel_1d/layers.py | 26 +- tests/test_fx/test_complete_workflow.py | 17 +- .../test_1d/checks_1d/check_layer_1d.py | 1045 +++++++++-------- tests/test_layers/test_1d/test_1d.py | 95 +- 4 files changed, 629 insertions(+), 554 deletions(-) diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 88ecdf691..1976da95a 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -597,9 +597,12 @@ class Linear1D_Row(ParallelLayer): parallel_input: bool = True, skip_bias_add: bool = False, weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), - bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): + bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), + stream_chunk_num: int = 1): super().__init__() + self.stream_chunk_num = stream_chunk_num + # Keep input parameters self.in_features = in_features self.out_features = out_features @@ -617,6 +620,9 @@ class Linear1D_Row(ParallelLayer): factory_kwargs = {'device': get_current_device(), 'dtype': dtype} self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + if self.stream_chunk_num > 1: + # TODO() work for inference only + self.chunk_weight() if bias: self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs)) else: @@ -626,6 +632,9 @@ class Linear1D_Row(ParallelLayer): self._set_tensor_parallel_attributes() set_parallel_input(False) + def chunk_weight(self): + self.weight_list = torch.chunk(self.weight, self.stream_chunk_num, dim=0) + def reset_parameters(self, weight_initializer, bias_initializer) -> None: fan_in, fan_out = self.in_features, self.out_features weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) @@ -696,10 +705,17 @@ class Linear1D_Row(ParallelLayer): input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) - output_parallel = F.linear(input_, self.weight) - # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) - output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) - + if self.stream_chunk_num > 1: + output_parallel_list = [None for i in range(self.stream_chunk_num)] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + output = torch.cat(output_parallel_list, dim=-1) + else: + print(input_.shape, self.weight.shape) + output_parallel = F.linear(input_, self.weight) + # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) + output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) if not self.skip_bias_add: if self.bias is not None: output = output + self.bias diff --git a/tests/test_fx/test_complete_workflow.py b/tests/test_fx/test_complete_workflow.py index 1d51e0a52..bb1a66812 100644 --- a/tests/test_fx/test_complete_workflow.py +++ b/tests/test_fx/test_complete_workflow.py @@ -32,7 +32,7 @@ class MLP(torch.nn.Module): return x -def run_workflow(world_size): +def run_workflow(world_size, dev): # initailization with LazyInitContext() as ctx: 
model = MLP(16) @@ -46,7 +46,7 @@ def run_workflow(world_size): gm = torch.fx.GraphModule(model, graph, model.__class__.__name__) # annotate - annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup()) + annotated_gm = transformer_mlp_pass(gm, process_group=ProcessGroup(tp_degree=world_size)) annotated_gm.recompile() # materialization and sharding @@ -61,22 +61,25 @@ def run_workflow(world_size): # test forward to make sure that IR transform will produce the same results # like how ColoTensor would do it normally - data = torch.rand(4, 16) + data = torch.rand(4, 16, device=dev) non_fx_out = model(data) fx_out = annotated_gm(data) assert torch.equal(non_fx_out, fx_out), f'{non_fx_out} vs {fx_out}' -def run_dist(rank, world_size, port): +def run_dist(rank, world_size, dev, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_workflow(world_size) + run_workflow(world_size, dev) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2]) +@pytest.mark.parametrize('dev', ['cuda', 'cpu']) @rerun_if_address_is_in_use() -def test_complete_workflow(world_size): - run_func = partial(run_dist, world_size=world_size, port=free_port()) +def test_complete_workflow(world_size, dev): + if dev == 'cpu' and world_size > 1: + return + run_func = partial(run_dist, world_size=world_size, dev=dev, port=free_port()) mp.spawn(run_func, nprocs=world_size) diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py index 5e1681da9..7d77391ea 100644 --- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -1,496 +1,549 @@ -import torch -import torch.distributed as dist -from colossalai.context.parallel_mode import ParallelMode -from colossalai.core import global_context as gpc -from colossalai.global_variables import tensor_parallel_env as env -from colossalai.nn import (Classifier1D, Embedding1D, Linear1D_Col, Linear1D_Row, VanillaClassifier, - VocabParallelClassifier1D, VocabParallelCrossEntropyLoss1D, VocabParallelEmbedding1D) -from colossalai.utils import get_current_device, print_rank_0 -from torch.nn import Parameter - -from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal - - -def check_linear_col(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = 2 * HIDDEN_SIZE - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = A_master.clone() - A.requires_grad = True - - W_shape = (OUTPUT_SIZE, INPUT_SIZE) - W_master = torch.randn(W_shape, dtype=dtype, device=device) - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=0)[i] - W = W.clone() - W.requires_grad = True - - B_shape = (OUTPUT_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - dist.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - B = B.clone() - B.requires_grad = True - - layer.weight = Parameter(W) - layer.bias = Parameter(B) - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + 
B_master - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - - check_equal(out, C) - print_rank_0('linear_col forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, layer.bias.grad) - - print_rank_0('linear_col backward: pass') - - -def check_linear_row(): - device = get_current_device() - dtype = torch.float32 - INPUT_SIZE = HIDDEN_SIZE - OUTPUT_SIZE = 2 * HIDDEN_SIZE - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = torch.chunk(A_master, DEPTH, dim=-1)[i] - A = A.clone() - A.requires_grad = True - - W_shape = (INPUT_SIZE, OUTPUT_SIZE) - W_master = torch.randn(W_shape, dtype=dtype, device=device) - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=-1)[i] - W = W.clone() - W.requires_grad = True - - B_shape = (INPUT_SIZE) - B_master = torch.randn(B_shape, dtype=dtype, device=device) - dist.broadcast(B_master, src=0) - B = B_master.clone() - B.requires_grad = True - - layer.weight = Parameter(W) - layer.bias = Parameter(B) - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - W_master = W_master.clone() - W_master.requires_grad = True - B_master = B_master.clone() - B_master.requires_grad = True - C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master - C = C_master.clone() - - check_equal(out, C) - print_rank_0('linear_row forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] - check_equal(A_grad, A.grad) - - W_grad = W_master.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = B_master.grad - check_equal(B_grad, layer.bias.grad) - - print_rank_0('linear_row backward: pass') - - -def check_embed(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] - embed.weight.data.copy_(weight) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = embed(A) - - A_master = A_master.clone() - C_master = embed_master(A_master) - C = C_master.clone() - check_equal(out, C) - print_rank_0('embed forward: pass') - - grad_shape = 
C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - grad_master = grad_master.clone() - C_master.backward(grad_master) - - B_grad = embed_master.weight.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] - check_equal(B_grad, embed.weight.grad) - print_rank_0('embed backward: pass') - - -def check_vocab_parallel_embed(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=0)[i] - embed.weight.data.copy_(weight) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = embed(A) - - A_master = A_master.clone() - C_master = embed_master(A_master) - C = C_master.clone() - check_equal(out, C) - print_rank_0('vocab parallel embed forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - torch.distributed.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - grad_master = grad_master.clone() - C_master.backward(grad_master) - - B_grad = embed_master.weight.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, embed.weight.grad) - print_rank_0('vocab parallel embed backward: pass') - - -def check_classifier_no_given_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - env.parallel_input_1d = False - parallel_input_1d = env.parallel_input_1d - layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) - layer_master = layer_master.to(dtype).to(device) - - W_master = layer_master.weight.data - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=-1)[i] - layer.weight.data.copy_(W) - B_master = layer_master.bias.data - dist.broadcast(B_master, src=0) - B = B_master.clone() - layer.bias.data.copy_(B) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - if parallel_input_1d: - A = torch.chunk(A_master, DEPTH, dim=-1)[i] - A = A.clone() - else: - A = A_master.clone() - A.requires_grad = True - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - C_master = layer_master(A_master) - C = C_master.clone() - - check_equal(out, C) - print_rank_0('classifier (no given weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - if parallel_input_1d: - A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] - check_equal(A_grad, A.grad) - - W_grad = layer_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = layer_master.bias.grad - 
check_equal(B_grad, layer.bias.grad) - - print_rank_0('classifier (no given weight) backward: pass') - - -def check_vocab_parallel_classifier_no_given_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) - layer_master = layer_master.to(dtype).to(device) - - W_master = layer_master.weight.data - dist.broadcast(W_master, src=0) - W = torch.chunk(W_master, DEPTH, dim=0)[i] - layer.weight.data.copy_(W) - B_master = layer_master.bias.data - dist.broadcast(B_master, src=0) - B = torch.chunk(B_master, DEPTH, dim=0)[i] - layer.bias.data.copy_(B) - - A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) - A_master = torch.randn(A_shape, dtype=dtype, device=device) - dist.broadcast(A_master, src=0) - A = A_master.clone() - A.requires_grad = True - - out = layer(A) - - A_master = A_master.clone() - A_master.requires_grad = True - C_master = layer_master(A_master) - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - - check_equal(out, C) - print_rank_0('vocab parallel classifier (no given weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - A_grad = A_master.grad - check_equal(A_grad, A.grad) - - W_grad = layer_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, layer.weight.grad) - - B_grad = layer_master.bias.grad - B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] - check_equal(B_grad, layer.bias.grad) - - print_rank_0('vocab parallel classifier (no given weight) backward: pass') - - -def check_classifier_given_embed_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] - embed.weight.data.copy_(weight) - - env.parallel_input_1d = False - layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) - layer_master = layer_master.to(dtype).to(device) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = layer(embed(A)) - - A_master = A_master.clone() - C_master = layer_master(embed_master(A_master)) - C = C_master.clone() - check_equal(out, C) - print_rank_0('classifier (given embed weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = grad_master.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - - W_grad = embed_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] - check_equal(W_grad, embed.weight.grad) - - 
print_rank_0('classifier (given embed weight) backward: pass') - - -def check_vocab_parallel_classifier_given_embed_weight(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) - embed = embed.to(dtype).to(device) - embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) - embed_master = embed_master.to(dtype).to(device) - - weight_master = embed_master.weight.data - torch.distributed.broadcast(weight_master, src=0) - weight = torch.chunk(weight_master, DEPTH, dim=0)[i] - embed.weight.data.copy_(weight) - - env.parallel_input_1d = False - layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) - layer.to(dtype).to(device) - - layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) - layer_master = layer_master.to(dtype).to(device) - - A_shape = (BATCH_SIZE, SEQ_LENGTH) - A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) - torch.distributed.broadcast(A_master, src=0) - A = A_master.clone() - out = layer(embed(A)) - - A_master = A_master.clone() - C_master = layer_master(embed_master(A_master)) - C = torch.chunk(C_master, DEPTH, dim=-1)[i] - check_equal(out, C) - print_rank_0('vocab parallel classifier (given embed weight) forward: pass') - - grad_shape = C_master.shape - grad_master = torch.randn(grad_shape, dtype=dtype, device=device) - dist.broadcast(grad_master, src=0) - grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] - grad = grad.clone() - out.backward(grad) - - grad_master = grad_master.clone() - C_master.backward(grad_master) - - W_grad = embed_master.weight.grad - W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] - check_equal(W_grad, embed.weight.grad) - - print_rank_0('vocab parallel classifier (given embed weight) backward: pass') - - -def check_vocab_parallel_loss(): - device = get_current_device() - dtype = torch.float32 - - i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - - criterion = VocabParallelCrossEntropyLoss1D() - criterion_master = torch.nn.CrossEntropyLoss() - - out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) - out_master = torch.randn(out_shape, dtype=dtype, device=device) - target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) - torch.distributed.broadcast(out_master, src=0) - torch.distributed.broadcast(target_master, src=0) - out = torch.chunk(out_master, DEPTH, dim=-1)[i] - out = out.clone() - out.requires_grad = True - - loss = criterion(out, target_master) - - out_master = out_master.clone() - out_master.requires_grad = True - loss_master = criterion_master(out_master, target_master) - check_equal(loss, loss_master) - print_rank_0('vocab parallel loss forward: pass') - - loss.backward() - loss_master.backward() - - out_grad = out_master.grad - out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] - check_equal(out_grad, out.grad) - print_rank_0('vocab parallel loss backward: pass') +import torch +import torch.distributed as dist +from torch.nn import Parameter + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.core import global_context as gpc +from colossalai.global_variables import tensor_parallel_env as env +from colossalai.nn import ( + Classifier1D, + Embedding1D, + Linear1D_Col, + Linear1D_Row, + VanillaClassifier, + VocabParallelClassifier1D, + VocabParallelCrossEntropyLoss1D, + VocabParallelEmbedding1D, +) +from colossalai.utils import get_current_device, print_rank_0 + 
+from .common import BATCH_SIZE, DEPTH, HIDDEN_SIZE, NUM_CLASSES, SEQ_LENGTH, VOCAB_SIZE, check_equal + + +def check_linear_col(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Col(INPUT_SIZE, OUTPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + W_shape = (OUTPUT_SIZE, INPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (OUTPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + B = B.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('linear_col forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_col backward: pass') + + +def check_linear_row(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + A.requires_grad = True + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + W.requires_grad = True + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + B.requires_grad = True + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + W_master = W_master.clone() + W_master.requires_grad = True + B_master = B_master.clone() + B_master.requires_grad = True + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device()) + dist.broadcast(grad_master, src=0) + grad = 
grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = W_master.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = B_master.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('linear_row backward: pass') + + +def check_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('embed backward: pass') + + +def check_vocab_parallel_embed(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = embed(A) + + A_master = A_master.clone() + C_master = embed_master(A_master) + C = C_master.clone() + check_equal(out, C) + print_rank_0('vocab parallel embed forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + torch.distributed.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + grad_master = grad_master.clone() + C_master.backward(grad_master) + + B_grad = embed_master.weight.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, embed.weight.grad) + print_rank_0('vocab parallel embed backward: pass') + + +def check_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + env.parallel_input_1d = False + parallel_input_1d = env.parallel_input_1d + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = 
layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = B_master.clone() + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + if parallel_input_1d: + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + else: + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = C_master.clone() + + check_equal(out, C) + print_rank_0('classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + if parallel_input_1d: + A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[i] + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + check_equal(B_grad, layer.bias.grad) + + print_rank_0('classifier (no given weight) backward: pass') + + +def check_vocab_parallel_classifier_no_given_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + layer = VocabParallelClassifier1D(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, VOCAB_SIZE, bias=True) + layer_master = layer_master.to(dtype).to(device) + + W_master = layer_master.weight.data + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=0)[i] + layer.weight.data.copy_(W) + B_master = layer_master.bias.data + dist.broadcast(B_master, src=0) + B = torch.chunk(B_master, DEPTH, dim=0)[i] + layer.bias.data.copy_(B) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = A_master.clone() + A.requires_grad = True + + out = layer(A) + + A_master = A_master.clone() + A_master.requires_grad = True + C_master = layer_master(A_master) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + + check_equal(out, C) + print_rank_0('vocab parallel classifier (no given weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + A_grad = A_master.grad + check_equal(A_grad, A.grad) + + W_grad = layer_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, layer.weight.grad) + + B_grad = layer_master.bias.grad + B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i] + check_equal(B_grad, layer.bias.grad) + + print_rank_0('vocab parallel classifier (no given weight) backward: pass') + + +def check_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = Embedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, 
HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=-1)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = Classifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = C_master.clone() + check_equal(out, C) + print_rank_0('classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = grad_master.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_classifier_given_embed_weight(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + embed = VocabParallelEmbedding1D(VOCAB_SIZE, HIDDEN_SIZE) + embed = embed.to(dtype).to(device) + embed_master = torch.nn.Embedding(VOCAB_SIZE, HIDDEN_SIZE) + embed_master = embed_master.to(dtype).to(device) + + weight_master = embed_master.weight.data + torch.distributed.broadcast(weight_master, src=0) + weight = torch.chunk(weight_master, DEPTH, dim=0)[i] + embed.weight.data.copy_(weight) + + env.parallel_input_1d = False + layer = VocabParallelClassifier1D(HIDDEN_SIZE, NUM_CLASSES, weight=embed.weight, bias=False) + layer.to(dtype).to(device) + + layer_master = VanillaClassifier(HIDDEN_SIZE, NUM_CLASSES, weight=embed_master.weight, bias=False) + layer_master = layer_master.to(dtype).to(device) + + A_shape = (BATCH_SIZE, SEQ_LENGTH) + A_master = torch.randint(VOCAB_SIZE, A_shape, device=device) + torch.distributed.broadcast(A_master, src=0) + A = A_master.clone() + out = layer(embed(A)) + + A_master = A_master.clone() + C_master = layer_master(embed_master(A_master)) + C = torch.chunk(C_master, DEPTH, dim=-1)[i] + check_equal(out, C) + print_rank_0('vocab parallel classifier (given embed weight) forward: pass') + + grad_shape = C_master.shape + grad_master = torch.randn(grad_shape, dtype=dtype, device=device) + dist.broadcast(grad_master, src=0) + grad = torch.chunk(grad_master, DEPTH, dim=-1)[i] + grad = grad.clone() + out.backward(grad) + + grad_master = grad_master.clone() + C_master.backward(grad_master) + + W_grad = embed_master.weight.grad + W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i] + check_equal(W_grad, embed.weight.grad) + + print_rank_0('vocab parallel classifier (given embed weight) backward: pass') + + +def check_vocab_parallel_loss(): + device = get_current_device() + dtype = torch.float32 + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + criterion = VocabParallelCrossEntropyLoss1D() + criterion_master = torch.nn.CrossEntropyLoss() + + out_shape = (BATCH_SIZE, SEQ_LENGTH, NUM_CLASSES) + out_master = torch.randn(out_shape, 
dtype=dtype, device=device) + target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, SEQ_LENGTH), dtype=torch.long, device=device) + torch.distributed.broadcast(out_master, src=0) + torch.distributed.broadcast(target_master, src=0) + out = torch.chunk(out_master, DEPTH, dim=-1)[i] + out = out.clone() + out.requires_grad = True + + loss = criterion(out, target_master) + + out_master = out_master.clone() + out_master.requires_grad = True + loss_master = criterion_master(out_master, target_master) + check_equal(loss, loss_master) + print_rank_0('vocab parallel loss forward: pass') + + loss.backward() + loss_master.backward() + + out_grad = out_master.grad + out_grad = torch.chunk(out_grad, DEPTH, dim=-1)[i] + check_equal(out_grad, out.grad) + print_rank_0('vocab parallel loss backward: pass') + + +@torch.no_grad() +def check_linear_row_stream_inference(): + device = get_current_device() + dtype = torch.float32 + INPUT_SIZE = HIDDEN_SIZE + OUTPUT_SIZE = 2 * HIDDEN_SIZE + + i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) + + assert HIDDEN_SIZE % 2 == 0 + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2) + + A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) + A_master = torch.randn(A_shape, dtype=dtype, device=device) + dist.broadcast(A_master, src=0) + A = torch.chunk(A_master, DEPTH, dim=-1)[i] + A = A.clone() + + W_shape = (INPUT_SIZE, OUTPUT_SIZE) + W_master = torch.randn(W_shape, dtype=dtype, device=device) + dist.broadcast(W_master, src=0) + W = torch.chunk(W_master, DEPTH, dim=-1)[i] + W = W.clone() + + B_shape = (INPUT_SIZE) + B_master = torch.randn(B_shape, dtype=dtype, device=device) + dist.broadcast(B_master, src=0) + B = B_master.clone() + + layer.weight = Parameter(W) + layer.bias = Parameter(B) + layer.chunk_weight() + out = layer(A) + + A_master = A_master.clone() + W_master = W_master.clone() + B_master = B_master.clone() + C_master = torch.matmul(A_master, W_master.transpose(0, 1)) + B_master + C = C_master.clone() + + check_equal(out, C) + print_rank_0('linear_row forward: pass') diff --git a/tests/test_layers/test_1d/test_1d.py b/tests/test_layers/test_1d/test_1d.py index cbdcb1b72..897590f0d 100644 --- a/tests/test_layers/test_1d/test_1d.py +++ b/tests/test_layers/test_1d/test_1d.py @@ -1,46 +1,49 @@ -#!/usr/bin/env python -# -*- encoding: utf-8 -*- - -from functools import partial - -import pytest -import torch -import torch.multiprocessing as mp -from colossalai.core import global_context as gpc -from colossalai.logging import disable_existing_loggers -from colossalai.initialize import launch -from colossalai.utils import free_port -from colossalai.testing import rerun_if_address_is_in_use -from checks_1d.check_layer_1d import * - -CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) - - -def check_layer(rank, world_size, port): - disable_existing_loggers() - launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - - check_linear_col() - check_linear_row() - check_embed() - check_vocab_parallel_embed() - check_classifier_no_given_weight() - check_vocab_parallel_classifier_no_given_weight() - check_classifier_given_embed_weight() - check_vocab_parallel_classifier_given_embed_weight() - check_vocab_parallel_loss() - - gpc.destroy() - torch.cuda.empty_cache() - - -@pytest.mark.dist -@rerun_if_address_is_in_use() -def test_1d(): - world_size = 4 - run_func = partial(check_layer, world_size=world_size, port=free_port()) - mp.spawn(run_func, nprocs=world_size) - - -if __name__ == 
'__main__': - test_1d() +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +from functools import partial + +import pytest +import torch +import torch.multiprocessing as mp +from checks_1d.check_layer_1d import * + +from colossalai.core import global_context as gpc +from colossalai.initialize import launch +from colossalai.logging import disable_existing_loggers +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils import free_port + +CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=4, mode='1d')),) + + +def check_layer(rank, world_size, port): + disable_existing_loggers() + launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + check_linear_col() + check_linear_row() + check_embed() + check_vocab_parallel_embed() + check_classifier_no_given_weight() + check_vocab_parallel_classifier_no_given_weight() + check_classifier_given_embed_weight() + check_vocab_parallel_classifier_given_embed_weight() + check_vocab_parallel_loss() + + check_linear_row_stream_inference() + + gpc.destroy() + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_1d(): + world_size = 4 + run_func = partial(check_layer, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + +if __name__ == '__main__': + test_1d()
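
Usage note (an addition for readers, not part of the commit): the new stream_chunk_num option splits the row-parallel weight into chunks along the output dimension; the forward pass then runs F.linear and the tensor-parallel reduce once per chunk and concatenates the results, and is intended for inference only. Below is a minimal sketch of driving that path, mirroring check_linear_row_stream_inference above. It assumes colossalai.launch(...) has already set up a '1d' tensor-parallel context (as test_1d.py does), and all sizes and names are illustrative.

    import torch

    from colossalai.core import global_context as gpc
    from colossalai.nn import Linear1D_Row
    from colossalai.utils import get_current_device

    hidden = 256    # illustrative; in_features (2 * hidden) must divide evenly across the tensor-parallel group
    # Row-parallel linear whose weight is split into 2 stream chunks.
    layer = Linear1D_Row(2 * hidden, hidden, stream_chunk_num=2)

    # __init__ already chunks the weight; call chunk_weight() again only if the
    # weight is replaced by hand afterwards (the test above does this).
    layer.chunk_weight()

    # With parallel_input=True (the default), the layer expects the per-rank
    # shard of the input along the last dimension.
    shard = 2 * hidden // gpc.tensor_parallel_size
    x = torch.randn(4, 32, shard, device=get_current_device())

    with torch.no_grad():    # the chunked path is inference-only
        y = layer(x)         # per-chunk F.linear + reduce_input, then torch.cat
    assert y.shape == (4, 32, hidden)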