diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 2fd10de06..06544401c 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -17,6 +17,64 @@ import random
 import os
 import numpy as np
 
+# Hack: replace huggingface ModelOutput.__post_init__ with a copy whose
+# is_tensor check also accepts our ColoTensor
+from transformers.file_utils import ModelOutput
+from dataclasses import fields
+
+
+def post_init_colo(self):
+    class_fields = fields(self)
+    # Safety and consistency checks
+    if not len(class_fields):
+        raise ValueError(f"{self.__class__.__name__} has no fields.")
+    if not all(field.default is None for field in class_fields[1:]):
+        raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
+
+    first_field = getattr(self, class_fields[0].name)
+    other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+    def is_tensor_with_colo(x):
+        """
+        Tests if `x` is a `ColoTensor` or `torch.Tensor`.
+        """
+        if isinstance(x, torch.Tensor):
+            return True
+
+        return isinstance(x, ColoTensor)
+
+    if other_fields_are_none and not is_tensor_with_colo(first_field):
+        if isinstance(first_field, dict):
+            iterator = first_field.items()
+            first_field_iterator = True
+        else:
+            try:
+                iterator = iter(first_field)
+                first_field_iterator = True
+            except TypeError:
+                first_field_iterator = False
+
+        # if we provided an iterator as first field and the iterator is a (key, value) iterator
+        # set the associated fields
+        if first_field_iterator:
+            for element in iterator:
+                if (
+                    not isinstance(element, (list, tuple))
+                    or not len(element) == 2
+                    or not isinstance(element[0], str)
+                ):
+                    break
+                setattr(self, element[0], element[1])
+                if element[1] is not None:
+                    self[element[0]] = element[1]
+        elif first_field is not None:
+            self[class_fields[0].name] = first_field
+    else:
+        for field in class_fields:
+            v = getattr(self, field.name)
+            if v is not None:
+                self[field.name] = v
+
+
+ModelOutput.__post_init__ = post_init_colo
+# end of the ModelOutput hack
+
 
 def set_seed(seed):
     random.seed(seed)
@@ -64,7 +122,7 @@ def run_1d_col_tp():
     model_torch = model_torch.cuda()
 
     # A naive way to set spec for all weights in Linear
-    for name, p in named_params_with_colotensor(model):
+    for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'proj1' in name and ('weight' in name or 'bias' in name):
@@ -249,6 +307,60 @@ def run_1d_row_tp():
         if i > 5:
             break
 
+
+def run_bert_1d():
+    get_components_func = non_distributed_component_funcs.get_callable('bert')
+    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
+    device = get_current_device()
+
+    set_seed(1)
+    with ColoInitContext(device=device):
+        model = model_builder(checkpoint=True)
+
+    # parallel_action_list_row = [
+    #     ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+    # ]
+    # spec_row = TensorSpec(parallel_action_list_row)
+
+    parallel_action_list_col = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec_col = TensorSpec(parallel_action_list_col)
+
+    parallel_action_list_embedding_col = [
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+    ]
+    spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
+
+    # Shard the classifier column-wise and the embeddings along the embedding dim
+    for name, p in model.colo_named_parameters():
+        if not isinstance(p, ColoTensor):
+            continue
+        # print(name)
+        if 'classifier' in name and ('weight' in name or 'bias' in name):
+            p.set_spec(spec_col)
+        if '_embeddings' in name and 'weight' in name:
+            p.set_spec(spec_embedding_col)
+
+    # for name, p in model.colo_named_parameters():
+    #     if not isinstance(p, ColoTensor):
+    #         continue
+    #     print(f"{name}: is_gathered {p.is_gathered()}")
+
+    model = model.cuda()
+
+    for i, (data, label) in enumerate(train_dataloader):
+        if i > 5:
+            break
+        data = data.to(device)
+        label = label.to(device)
+
+        model.train()
+        if criterion:
+            output = model(data)
+            loss = criterion(output, label)
+        else:
+            output = model(data, label)
+            loss = output
+
+        loss.backward()
+
 
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
@@ -256,16 +368,30 @@
     run_1d_row_tp()
     run_1d_col_tp()
 
+
+def run_dist_bert(rank, world_size, port):
+    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_bert_1d()
+
+
 @pytest.mark.dist
-@parameterize('world_size', [1, 4])
+@pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
 def test_simple_net(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
+@pytest.mark.dist
+# @pytest.mark.parametrize('world_size', [1, 4])
+# Don't really add it to pytest for now; after Classifier and Loss are finished, I (jzy) will remove this annotation.
+@parameterize('world_size', [1])
+@rerun_if_address_is_in_use()
+def test_bert(world_size):
+    run_func = partial(run_dist_bert, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
     # test_simple_net()
-    test_model_parameters()
+    # test_model_parameters()
     # test_colo_optimizer()
+    test_bert()
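For context on the `__post_init__` override at the top of the patch: the stock `ModelOutput.__post_init__` recognizes tensors via an `is_tensor` check that only accepts `torch.Tensor` (plus numpy/tf types). A first field that fails the check is probed with `iter()` and, when iterable, walked as a would-be `(key, value)` iterator, so a tensor-like object can silently never reach the output dict. Below is a minimal sketch of that failure mode; `DummyOutput` and `FakeColoTensor` are hypothetical stand-ins so the snippet runs without colossalai, and the behavior described is that of the `transformers.file_utils`-era `ModelOutput`:

```python
# Sketch of the failure mode the post_init_colo hack works around.
# FakeColoTensor and DummyOutput are illustrative stand-ins, not part of the
# patch; the point is only that an iterable non-torch.Tensor first field is
# dropped by the stock ModelOutput.__post_init__.
import torch
from dataclasses import dataclass
from typing import Optional
from transformers.file_utils import ModelOutput


class FakeColoTensor:
    """Tensor-like and iterable, but not a torch.Tensor subclass."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)


@dataclass
class DummyOutput(ModelOutput):
    logits: Optional[object] = None


out = DummyOutput(logits=FakeColoTensor(torch.randn(2, 4)))
# The stock __post_init__ fails the is_tensor check, probes the field with
# iter(), finds no (str, value) pairs, and stores nothing in the dict:
print(len(out))        # 0 -- the field never reached the output dict
print(out.to_tuple())  # () -- so positional indexing of the output breaks
# With is_tensor_with_colo (and a real ColoTensor), the field is kept
# as-is, exactly like a torch.Tensor.
```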
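Separately, the `TP1DCol_Linear` spec applied to the classifier shards a `Linear` weight along its output dimension. A toy, single-process illustration of why that sharding composes (plain `torch`, no colossalai; `world_size = 2` is an arbitrary choice):

```python
import torch

# Toy illustration of 1D column parallelism (cf. TP1DCol_Linear): split the
# Linear weight along the output dim, let each "rank" compute its slice of
# the output, then concatenate the slices to recover the full result.
torch.manual_seed(0)
world_size = 2
linear = torch.nn.Linear(8, 4, bias=False)
x = torch.randn(3, 8)

shards = linear.weight.chunk(world_size, dim=0)  # one shard per rank
partials = [x @ w.t() for w in shards]           # per-rank partial outputs
full = torch.cat(partials, dim=-1)               # all-gather along output dim

assert torch.allclose(full, linear(x), atol=1e-6)
```

The `TP1DCol_Embedding` spec used for `_embeddings` weights is the analogous split of the embedding table along the embedding (hidden) dimension.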