From 8e3d0ad8f1d098baf9731eb4f57b90c6c2c0a34e Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 19 May 2022 18:57:56 +0800 Subject: [PATCH] [unit test] refactor test tensor (#1005) * polish test_gpt * update op unit tests * update test model --- tests/components_to_test/__init__.py | 2 +- tests/components_to_test/gpt.py | 79 +++++++++++ tests/test_tensor/_utils/_util.py | 30 +++++ tests/test_tensor/test_addmm_tp.py | 26 +--- tests/test_tensor/test_embedding_tp.py | 23 +--- tests/test_tensor/test_gpt.py | 175 ++----------------------- tests/test_tensor/test_linear_tp.py | 26 +--- tests/test_tensor/test_model.py | 72 +--------- 8 files changed, 143 insertions(+), 290 deletions(-) create mode 100644 tests/components_to_test/gpt.py diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py index 099bbe813..f87d35ff9 100644 --- a/tests/components_to_test/__init__.py +++ b/tests/components_to_test/__init__.py @@ -1 +1 @@ -from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net +from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt diff --git a/tests/components_to_test/gpt.py b/tests/components_to_test/gpt.py new file mode 100644 index 000000000..4d72180d8 --- /dev/null +++ b/tests/components_to_test/gpt.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from .registry import non_distributed_component_funcs +from transformers import GPT2Config, GPT2LMHeadModel +from .utils.dummy_data_generator import DummyDataGenerator +from colossalai.utils.cuda import get_current_device + + +class DummyDataLoader(DummyDataGenerator): + vocab_size = 50304 + batch_size = 4 + seq_len = 1024 + + def generate(self): + input_ids = torch.randint(0, + DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len), + device=get_current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +class GPTLMModel(nn.Module): + + def __init__(self, + hidden_size=768, + num_layers=12, + num_attention_heads=12, + max_seq_len=1024, + vocab_size=50304, + checkpoint=False): + super().__init__() + self.checkpoint = checkpoint + self.model = GPT2LMHeadModel( + GPT2Config(n_embd=hidden_size, + n_layer=num_layers, + n_head=num_attention_heads, + n_positions=max_seq_len, + n_ctx=max_seq_len, + vocab_size=vocab_size, + resid_pdrop=0.0, + embd_pdrop=0.0, + attn_pdrop=0.0)) + if checkpoint: + self.model.gradient_checkpointing_enable() + + def forward(self, input_ids, attention_mask): + # Only return lm_logits + return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] + + +def gpt2_s(checkpoint=True): + return GPTLMModel(checkpoint=checkpoint) + + +def gpt2_m(checkpoint=True): + return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) + + +class GPTLMLoss(nn.Module): + + def __init__(self): + super().__init__() + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, logits, labels): + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + +@non_distributed_component_funcs.register(name='gpt2') +def get_training_components(): + + trainloader = DummyDataLoader() + testloader = DummyDataLoader() + + criterion = GPTLMLoss() + return gpt2_s, trainloader, testloader, torch.optim.Adam, criterion diff --git 
a/tests/test_tensor/_utils/_util.py b/tests/test_tensor/_utils/_util.py
index 6fd595aa4..32f912955 100644
--- a/tests/test_tensor/_utils/_util.py
+++ b/tests/test_tensor/_utils/_util.py
@@ -1,5 +1,19 @@
+import os
+import random
+import numpy as np
 import torch
 import torch.distributed as dist
+from colossalai.core import global_context as gpc
+from colossalai.context import ParallelMode
+
+
+def set_seed(seed):
+    random.seed(seed)
+    os.environ['PYTHONHASHSEED'] = str(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.backends.cudnn.deterministic = True
 
 
 def check_equal(A, B):
@@ -25,3 +39,19 @@ def broadcast_tensor_chunk(tensor, chunk_size=1, local_rank=0):
 
 def tensor_equal(A, B):
     return torch.allclose(A, B, rtol=1e-3, atol=1e-1)
+
+
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+    assert tensor.ndim == shard.ndim
+    if tensor.shape == shard.shape:
+        return tensor_equal(tensor, shard)
+    else:
+        dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape))
+        if dims_not_eq.numel() == 1:
+            # 1D shard
+            dim = dims_not_eq.item()
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
+        else:
+            raise NotImplementedError
diff --git a/tests/test_tensor/test_addmm_tp.py b/tests/test_tensor/test_addmm_tp.py
index b5c19db10..b02f4baad 100644
--- a/tests/test_tensor/test_addmm_tp.py
+++ b/tests/test_tensor/test_addmm_tp.py
@@ -11,6 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
 from colossalai.core import global_context as gpc
+from _utils import tensor_shard_equal, tensor_equal
 
 
 class Conv1D(nn.Module):
@@ -45,13 +46,6 @@ def init_1d_row(weight, bias):
     weight.set_spec(spec)
 
 
-def check_grad_1d_row(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad, bias.grad)
-
-
 def init_1d_col(weight, bias):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
@@ -61,14 +55,7 @@ def init_1d_col(weight, bias):
     bias.set_spec(spec)
 
 
-def check_grad_1d_col(model: torch.nn.Module, weight, bias):
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-    assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad)
-    assert torch.allclose(model.bias.grad.chunk(size, -1)[rank], bias.grad)
-
-
-def run_with_spec(spec_init_func, check_grad_func):
+def run_with_spec(spec_init_func):
     model = Conv1D(4, 16).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
@@ -76,18 +63,19 @@ def run_with_spec(spec_init_func, check_grad_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
-    assert torch.allclose(out, colo_out)
+    assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    check_grad_func(model, weight, bias)
+    assert tensor_shard_equal(model.weight.grad, weight.grad)
+    assert tensor_shard_equal(model.bias.grad, bias.grad)
 
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(init_1d_row, check_grad_1d_row) - run_with_spec(init_1d_col, check_grad_1d_col) + run_with_spec(init_1d_row) + run_with_spec(init_1d_col) @pytest.mark.dist diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py index 1c687d53d..71d0c52bc 100644 --- a/tests/test_tensor/test_embedding_tp.py +++ b/tests/test_tensor/test_embedding_tp.py @@ -12,6 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager +from _utils import tensor_equal, tensor_shard_equal def init_1d_row(weight): @@ -22,12 +23,6 @@ def init_1d_row(weight): weight.set_spec(spec) -def check_grad_1d_row(model: torch.nn.Module, weight): - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad) - - def init_1d_col(weight): spec = TensorSpec( distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), @@ -36,31 +31,25 @@ def init_1d_col(weight): weight.set_spec(spec) -def check_grad_1d_col(model: torch.nn.Module, weight): - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad) - - -def run_with_spec(spec_init_func, check_grad_func): +def run_with_spec(spec_init_func): model = torch.nn.Embedding(12, 32).cuda() weight = ColoTensor(torch.nn.Parameter(model.weight.detach())) spec_init_func(weight) x = torch.tensor((0, 3, 6, 9)).cuda() out = model(x) colo_out = F.embedding(x, weight) - assert torch.allclose(out, colo_out) + assert tensor_equal(out, colo_out) grad = torch.rand_like(out) out.backward(grad) colo_out.backward(grad) - check_grad_func(model, weight) + assert tensor_shard_equal(model.weight.grad, weight.grad) def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(init_1d_row, check_grad_1d_row) - run_with_spec(init_1d_col, check_grad_1d_col) + run_with_spec(init_1d_row) + run_with_spec(init_1d_col) @pytest.mark.dist diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py index 369671af8..9e1671280 100644 --- a/tests/test_tensor/test_gpt.py +++ b/tests/test_tensor/test_gpt.py @@ -1,142 +1,16 @@ import pytest import colossalai -import os -import random -import numpy as np -import torch -import torch.nn as nn from colossalai.context.parallel_mode import ParallelMode -from transformers import GPT2Config, GPT2LMHeadModel import torch.multiprocessing as mp from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils.cuda import get_current_device from colossalai.utils import free_port from colossalai.utils import ColoInitContext -from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager, distspec +from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec from colossalai.core import global_context as gpc from functools import partial -# Hack huggingface Bert ModelOutput 
-# Make it available to our ColoTensor -from transformers.file_utils import ModelOutput -from dataclasses import fields -from tests.test_tensor._utils import tensor_equal - - -def _post_init_colotensor(self): - class_fields = fields(self) - # Safety and consistency checks - if len(class_fields) == 0: - raise ValueError(f"{self.__class__.__name__} has no fields.") - if not all(field.default is None for field in class_fields[1:]): - raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") - - first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) - - def is_tensor_with_colo(x): - """ - Tests if `x` is a `ColoTensor` or `torch.Tensor`. - """ - if isinstance(x, torch.Tensor): - return True - - return isinstance(x, ColoTensor) - - if other_fields_are_none and not is_tensor_with_colo(first_field): - if isinstance(first_field, dict): - iterator = first_field.items() - first_field_iterator = True - else: - try: - iterator = iter(first_field) - first_field_iterator = True - except TypeError: - first_field_iterator = False - - # if we provided an iterator as first field and the iterator is a (key, value) iterator - # set the associated fields - if first_field_iterator: - for element in iterator: - if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)): - break - setattr(self, element[0], element[1]) - if element[1] is not None: - self[element[0]] = element[1] - elif first_field is not None: - self[class_fields[0].name] = first_field - else: - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v - - -ModelOutput.__post_init__ = _post_init_colotensor - - -class GPTLMModel(nn.Module): - - def __init__(self, - hidden_size=768, - num_layers=12, - num_attention_heads=12, - max_seq_len=1024, - vocab_size=50304, - checkpoint=False): - super().__init__() - self.checkpoint = checkpoint - self.model = GPT2LMHeadModel( - GPT2Config(n_embd=hidden_size, - n_layer=num_layers, - n_head=num_attention_heads, - n_positions=max_seq_len, - n_ctx=max_seq_len, - vocab_size=vocab_size, - resid_pdrop=0.0, - embd_pdrop=0.0, - attn_pdrop=0.0)) - if checkpoint: - self.model.gradient_checkpointing_enable() - - def forward(self, input_ids, attention_mask): - # Only return lm_logits - return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0] - - -def gpt2_s(checkpoint=True): - return GPTLMModel(checkpoint=checkpoint) - - -def gpt2_m(checkpoint=True): - return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint) - - -class GPTLMLoss(nn.Module): - - def __init__(self): - super().__init__() - self.loss_fn = nn.CrossEntropyLoss() - - def forward(self, logits, labels): - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) - - -def set_seed(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - - -def get_data(batch_size, seq_len, vocab_size): - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask +from _utils import 
tensor_equal, tensor_shard_equal +from tests.components_to_test.registry import non_distributed_component_funcs def init_1d_row_spec(model): @@ -159,30 +33,6 @@ def init_1d_col_spec(model): p.set_spec(spec) -def check_tensor_equal_1d(tensor: torch.Tensor, shard: ColoTensor): - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - assert len(shard.spec.dist_spec.dims) == 1 - dim = shard.spec.dist_spec.dims[0] - assert torch.equal(tensor.chunk(world_size, dim)[rank], shard.torch_tensor()) - - -def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor): - assert tensor.ndim == shard.ndim - if tensor.shape == shard.shape: - return tensor_equal(tensor, shard) - else: - dims_not_eq = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)) - if dims_not_eq.numel() == 1: - # 1D shard - dim = dims_not_eq.item() - world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - return tensor_equal(tensor.chunk(world_size, dim)[rank], shard) - else: - raise NotImplementedError - - def check_param_equal(model, torch_model): for p, torch_p in zip(model.parameters(), torch_model.parameters()): assert tensor_shard_equal(torch_p, p) @@ -194,23 +44,20 @@ def check_grad_equal(model, torch_model): def run_gpt(init_spec_func): - BATCH_SIZE = 4 - SEQ_LEN = 1024 - VOCAB_SIZE = 50304 - NUM_STEPS = 1 - criterion = GPTLMLoss() + get_components_func = non_distributed_component_funcs.get_callable('gpt2') + model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() + with ColoInitContext(device=get_current_device()): - model = gpt2_s() + model = model_builder() model = model.cuda() - torch_model = gpt2_s().cuda() + torch_model = model_builder().cuda() for torch_p, p in zip(torch_model.parameters(), model.parameters()): torch_p.data.copy_(p) init_spec_func(model) check_param_equal(model, torch_model) model.train() torch_model.train() - for i in range(NUM_STEPS): - input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE) + for i, (input_ids, attn_mask) in enumerate(train_dataloader): logits = model(input_ids, attn_mask) torch_logits = torch_model(input_ids, attn_mask) assert tensor_equal(torch_logits, logits) @@ -219,6 +66,8 @@ def run_gpt(init_spec_func): loss.backward() torch_loss.backward() check_grad_equal(model, torch_model) + if i > 0: + break def run_dist(rank, world_size, port): @@ -237,4 +86,4 @@ def test_gpt(world_size): if __name__ == '__main__': - test_gpt(1) + test_gpt(4) diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py index a009ceca5..ac9a8ece0 100644 --- a/tests/test_tensor/test_linear_tp.py +++ b/tests/test_tensor/test_linear_tp.py @@ -13,6 +13,7 @@ from colossalai.testing import rerun_if_address_is_in_use from colossalai.utils import free_port from colossalai.core import global_context as gpc from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager +from _utils import tensor_equal, tensor_shard_equal def init_1d_row(weight, bias): @@ -23,13 +24,6 @@ def init_1d_row(weight, bias): weight.set_spec(spec) -def check_grad_1d_row(model: torch.nn.Module, weight, bias): - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - assert torch.allclose(model.weight.grad.chunk(size, -1)[rank], weight.grad) - assert torch.allclose(model.bias.grad, bias.grad) - - def init_1d_col(weight, bias): spec = TensorSpec( 
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]), @@ -39,14 +33,7 @@ def init_1d_col(weight, bias): bias.set_spec(spec) -def check_grad_1d_col(model: torch.nn.Module, weight, bias): - rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - size = gpc.get_world_size(ParallelMode.PARALLEL_1D) - assert torch.allclose(model.weight.grad.chunk(size, 0)[rank], weight.grad) - assert torch.allclose(model.bias.grad.chunk(size, 0)[rank], bias.grad) - - -def run_with_spec(spec_init_func, check_grad_func): +def run_with_spec(spec_init_func): model = torch.nn.Linear(4, 8).cuda() weight = ColoTensor(torch.nn.Parameter(model.weight.detach())) bias = ColoTensor(torch.nn.Parameter(model.bias.detach())) @@ -54,18 +41,19 @@ def run_with_spec(spec_init_func, check_grad_func): x = torch.rand(2, 4).cuda() out = model(x) colo_out = F.linear(x, weight, bias) - assert torch.allclose(out, colo_out) + assert tensor_equal(out, colo_out) grad = torch.rand_like(out) out.backward(grad) colo_out.backward(grad) - check_grad_func(model, weight, bias) + assert tensor_shard_equal(model.weight.grad, weight.grad) + assert tensor_shard_equal(model.bias.grad, bias.grad) def run_dist(rank, world_size, port): config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),)) colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - run_with_spec(init_1d_row, check_grad_1d_row) - run_with_spec(init_1d_col, check_grad_1d_col) + run_with_spec(init_1d_row) + run_with_spec(init_1d_col) @pytest.mark.dist diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index bcaab6716..6a71242df 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -13,78 +13,8 @@ from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager from colossalai.context import ParallelMode from colossalai.core import global_context as gpc - from functools import partial -import random -import os -import numpy as np - -# Hack huggingface Bert ModelOutput -# Make it available to our ColoTensor -from transformers.file_utils import ModelOutput -from dataclasses import fields - - -def _post_init_colotensor(self): - class_fields = fields(self) - # Safety and consistency checks - if len(class_fields) == 0: - raise ValueError(f"{self.__class__.__name__} has no fields.") - if not all(field.default is None for field in class_fields[1:]): - raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") - - first_field = getattr(self, class_fields[0].name) - other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) - - def is_tensor_with_colo(x): - """ - Tests if `x` is a `ColoTensor` or `torch.Tensor`. 
- """ - if isinstance(x, torch.Tensor): - return True - - return isinstance(x, ColoTensor) - - if other_fields_are_none and not is_tensor_with_colo(first_field): - if isinstance(first_field, dict): - iterator = first_field.items() - first_field_iterator = True - else: - try: - iterator = iter(first_field) - first_field_iterator = True - except TypeError: - first_field_iterator = False - - # if we provided an iterator as first field and the iterator is a (key, value) iterator - # set the associated fields - if first_field_iterator: - for element in iterator: - if (not isinstance(element, (list, tuple)) or not len(element) == 2 or not isinstance(element[0], str)): - break - setattr(self, element[0], element[1]) - if element[1] is not None: - self[element[0]] = element[1] - elif first_field is not None: - self[class_fields[0].name] = first_field - else: - for field in class_fields: - v = getattr(self, field.name) - if v is not None: - self[field.name] = v - - -ModelOutput.__post_init__ = _post_init_colotensor -# complete the hack - - -def set_seed(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True +from _utils import set_seed def init_1d_row_linear(weight):