diff --git a/colossalai/shardformer/modeling/vit.py b/colossalai/shardformer/modeling/vit.py
new file mode 100644
index 000000000..f28c13ad0
--- /dev/null
+++ b/colossalai/shardformer/modeling/vit.py
@@ -0,0 +1,337 @@
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+from transformers.models.vit.modeling_vit import BaseModelOutput, ViTEncoder
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+def _encoder_forward(
+    encoder: ViTEncoder,
+    start_idx: int,
+    end_idx: int,
+    hidden_states: torch.Tensor,
+    head_mask: Optional[torch.Tensor] = None,
+    return_dict: bool = True,
+    stage_manager: Optional[PipelineStageManager] = None,
+) -> Union[tuple, BaseModelOutput]:
+
+    for i in range(start_idx, end_idx):
+        layer_module = encoder.layer[i]
+
+        layer_head_mask = head_mask[i] if head_mask is not None else None
+
+        if encoder.gradient_checkpointing and encoder.training:
+
+            def create_custom_forward(module):
+
+                def custom_forward(*inputs):
+                    return module(*inputs, False)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(layer_module),
+                hidden_states,
+                layer_head_mask,
+            )
+        else:
+            layer_outputs = layer_module(hidden_states, layer_head_mask, False)
+
+        hidden_states = layer_outputs[0]
+    if not stage_manager.is_last_stage():
+        return hidden_states
+    else:
+        if not return_dict:
+            return (hidden_states,)    # one-element tuple; tuple(hidden_states) would iterate the batch dim
+        return BaseModelOutput(
+            last_hidden_state=hidden_states,
+            hidden_states=None,
+            attentions=None,
+        )
+
+
+def ViTModel_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+    from transformers.models.vit.modeling_vit import BaseModelOutputWithPooling
+
+    def pp_forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+        """
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (output_hidden_states
+                                if output_hidden_states is not None else self.config.output_hidden_states)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if output_attentions:
+            logging.warning('Non-empty output_attentions is not supported for pipeline models at the moment.')
+            output_attentions = None
+        if output_hidden_states:
+            logging.warning('Non-empty output_hidden_states is not supported for pipeline models at the moment.')
+            output_hidden_states = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        if stage_manager.is_first_stage():
+            if pixel_values is None:
+                raise ValueError("You have to specify pixel_values")
+
+            # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
+            expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
+            if pixel_values.dtype != expected_dtype:
+                pixel_values = pixel_values.to(expected_dtype)
+
+            embedding_output = self.embeddings(pixel_values,
+                                               bool_masked_pos=bool_masked_pos,
+                                               interpolate_pos_encoding=interpolate_pos_encoding)
+        else:
+            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
+
+        # Go through encoder
+        if not stage_manager.is_last_stage():
+            hidden_states = _encoder_forward(
+                encoder=self.encoder,
+                start_idx=stage_index[0],
+                end_idx=stage_index[1],
+                hidden_states=embedding_output,
+                head_mask=head_mask,
+                return_dict=return_dict,
+                stage_manager=stage_manager,
+            )
+            return {'hidden_states': hidden_states}
+        else:
+            encoder_outputs = _encoder_forward(
+                encoder=self.encoder,
+                start_idx=stage_index[0],
+                end_idx=stage_index[1],
+                hidden_states=hidden_states,
+                head_mask=head_mask,
+                return_dict=return_dict,
+                stage_manager=stage_manager,
+            )
+
+        # Go through the rest of the layers
+        sequence_output = encoder_outputs[0]
+        sequence_output = self.layernorm(sequence_output)
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
+            return head_outputs + encoder_outputs[1:]
+
+        return BaseModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+        )
+
+    return pp_forward
+
+
+def ViTForImageClassification_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+    from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+    from transformers.models.vit.modeling_vit import ImageClassifierOutput
+
+    def pp_forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+    ) -> Union[tuple, ImageClassifierOutput]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if not stage_manager.is_first_stage():
+            assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None"
+
+        outputs = self.vit(
+            pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            hidden_states=hidden_states,
+        )
+
+        # not last stage, return hidden_states
+        if not stage_manager.is_last_stage():
+            return outputs
+        else:
+            sequence_output = outputs[0]
+
+        # last stage
+        logits = self.classifier(sequence_output[:, 0, :])
+
+        loss = None
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return ImageClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    return pp_forward
+
+
+def ViTForMaskedImageModeling_pipeline_forward(stage_manager: PipelineStageManager, stage_index: List[int]):
+
+    import math
+
+    import torch.nn as nn
+    from transformers.models.vit.modeling_vit import MaskedImageModelingOutput
+
+    def pp_forward(
+        self,
+        pixel_values: Optional[torch.Tensor] = None,
+        bool_masked_pos: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+    ) -> Union[tuple, MaskedImageModelingOutput]:
+        r"""
+        bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
+            Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
+ + Returns: + + Examples: + ```python + >>> from transformers import AutoImageProcessor, ViTForMaskedImageModeling + >>> import torch + >>> from PIL import Image + >>> import requests + + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k") + >>> model = ViTForMaskedImageModeling.from_pretrained("google/vit-base-patch16-224-in21k") + + >>> num_patches = (model.config.image_size // model.config.patch_size) ** 2 + >>> pixel_values = image_processor(images=image, return_tensors="pt").pixel_values + >>> # create random boolean mask of shape (batch_size, num_patches) + >>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool() + + >>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos) + >>> loss, reconstructed_pixel_values = outputs.loss, outputs.reconstruction + >>> list(reconstructed_pixel_values.shape) + [1, 3, 224, 224] + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if bool_masked_pos is not None and (self.config.patch_size != self.config.encoder_stride): + raise ValueError( + "When `bool_masked_pos` is provided, `patch_size` must be equal to `encoder_stride` to ensure that " + "the reconstructed image has the same dimensions as the input." + f"Got `patch_size` = {self.config.patch_size} and `encoder_stride` = {self.config.encoder_stride}.") + + if not stage_manager.is_first_stage(): + assert hidden_states is not None, f"Current stage is {stage_manager.stage}, hidden_states should not be None" + + outputs = self.vit(pixel_values, + bool_masked_pos=bool_masked_pos, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + interpolate_pos_encoding=interpolate_pos_encoding, + return_dict=return_dict, + hidden_states=hidden_states) + if not stage_manager.is_last_stage(): + return outputs + else: + sequence_output = outputs[0] + + # Reshape to (batch_size, num_channels, height, width) + sequence_output = sequence_output[:, 1:] + batch_size, sequence_length, num_channels = sequence_output.shape + height = width = math.floor(sequence_length**0.5) + sequence_output = sequence_output.permute(0, 2, 1).reshape(batch_size, num_channels, height, width) + + # Reconstruct pixel values + reconstructed_pixel_values = self.decoder(sequence_output) + + masked_im_loss = None + if bool_masked_pos is not None: + size = self.config.image_size // self.config.patch_size + bool_masked_pos = bool_masked_pos.reshape(-1, size, size) + mask = (bool_masked_pos.repeat_interleave(self.config.patch_size, + 1).repeat_interleave(self.config.patch_size, + 2).unsqueeze(1).contiguous()) + reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none") + masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / self.config.num_channels + + if not return_dict: + output = (reconstructed_pixel_values,) + outputs[1:] + return ((masked_im_loss,) + output) if masked_im_loss is not None else output + + return MaskedImageModelingOutput( + loss=masked_im_loss, + reconstruction=reconstructed_pixel_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + return pp_forward diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index b31f1b35f..d00a03c92 100644 --- 
a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -75,6 +75,14 @@ _POLICY_LIST = { "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification": PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"), + # ViT + "transformers.models.vit.modeling_vit.ViTModel": + PolicyLocation(file_name="vit", class_name="ViTModelPolicy"), + "transformers.models.vit.modeling_vit.ViTForImageClassification": + PolicyLocation(file_name="vit", class_name="ViTForImageClassificationPolicy"), + "transformers.models.vit.modeling_vit.ViTForMaskedImageModeling": + PolicyLocation(file_name="vit", class_name="ViTForMaskedImageModelingPolicy"), + # OPT "transformers.models.opt.modeling_opt.OPTModel": PolicyLocation(file_name="opt", class_name="OPTModelPolicy"), diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py index 3f6bbd106..47f2c58fc 100644 --- a/colossalai/shardformer/policies/vit.py +++ b/colossalai/shardformer/policies/vit.py @@ -1,12 +1,18 @@ -from typing import Dict, Union +from functools import partial +from typing import Callable, Dict, List, Union import torch.nn as nn -from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row +import colossalai.shardformer.layer as col_nn +from ..modeling.vit import ( + ViTForImageClassification_pipeline_forward, + ViTForMaskedImageModeling_pipeline_forward, + ViTModel_pipeline_forward, +) from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ['ViTPolicy'] +__all__ = ['ViTPolicy', 'ViTModelPolicy', 'ViTForImageClassificationPolicy', 'ViTForMaskedImageModelingPolicy'] class ViTPolicy(Policy): @@ -15,96 +21,203 @@ class ViTPolicy(Policy): pass def preprocess(self): - # Resize embedding - vocab_size = self.model.config.vocab_size - world_size = self.shard_config.tensor_parallel_size - - if vocab_size % world_size != 0: - new_vocab_size = vocab_size + world_size - vocab_size % world_size - self.model.resize_token_embeddings(new_vocab_size) - return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer - base_policy = { - ViTEmbeddings: - ModulePolicyDescription(sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="dropout", - target_module=DropoutForReplicatedInput, - ) - ]), - ViTLayer: - ModulePolicyDescription(attribute_replacement={ - "attention.attention.num_attention_heads": - self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, - "attention.attention.all_head_size": - self.model.config.hidden_size // self.shard_config.tensor_parallel_size, - }, - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="attention.attention.query", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.key", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.value", - target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="attention.attention.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="attention.output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="attention.output.dropout", - target_module=DropoutForParallelInput, - ), - SubModuleReplacementDescription( - suffix="intermediate.dense", 
- target_module=Linear1D_Col, - ), - SubModuleReplacementDescription( - suffix="output.dense", - target_module=Linear1D_Row, - ), - SubModuleReplacementDescription( - suffix="output.dropout", - target_module=DropoutForParallelInput, - ), - ]), - } + policy = {} - # optimization configuration - if self.shard_config.enable_fused_normalization: - base_policy[ViTAttention].sub_module_replacement.extend([ - SubModuleReplacementDescription( - suffix="layernorm_before", - target_module=FusedLayerNorm, - ), - SubModuleReplacementDescription( - suffix="layernorm_after", - target_module=FusedLayerNorm, - ) - ]) - base_policy[ViTModel].sub_module_replacement.append( - SubModuleReplacementDescription( - suffix="layernorm", - target_module=FusedLayerNorm, - )) + if self.shard_config.enable_tensor_parallelism: + policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=col_nn.DropoutForReplicatedInput, + ) + ]) - return base_policy + policy[ViTLayer] = ModulePolicyDescription(attribute_replacement={ + "attention.attention.num_attention_heads": + self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + "attention.attention.all_head_size": + self.model.config.hidden_size // self.shard_config.tensor_parallel_size, + }, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="attention.attention.query", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.key", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.value", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="attention.attention.dropout", + target_module=col_nn.DropoutForParallelInput, + ), + SubModuleReplacementDescription( + suffix="attention.output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="attention.output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + SubModuleReplacementDescription( + suffix="intermediate.dense", + target_module=col_nn.Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="output.dense", + target_module=col_nn.Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="output.dropout", + target_module=col_nn.DropoutForReplicatedInput, + ), + ]) + + return policy def new_model_class(self): return None def postprocess(self): return self.model + + def get_held_layers(self) -> List[nn.Module]: + """Get pipeline layers for current stage.""" + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = self.model.vit + stage_manager = self.pipeline_stage_manager + + held_layers = [] + layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + if stage_manager.is_first_stage(): + held_layers.append(module.embeddings) + start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage) + held_layers.extend(module.encoder.layer[start_idx:end_idx]) + return held_layers + + def set_pipeline_forward(self, model_cls: nn.Module, pipeline_forward: Callable, policy: Dict): + if self.pipeline_stage_manager: + stage_manager = self.pipeline_stage_manager + if self.model.__class__.__name__ == 'ViTModel': + module = self.model + else: + module = 
self.model.vit + + layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages) + stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) + method_replacement = {'forward': pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)} + self.append_or_create_method_replacement(description=method_replacement, + policy=policy, + target_key=model_cls) + + +# ViTModel +class ViTModelPolicy(ViTPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(module.pooler) + + return held_layers + + +# ViTForImageClassification +class ViTForImageClassificationPolicy(ViTPolicy): + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForImageClassification, ViTModel + + policy = super().module_policy() + if self.shard_config.enable_tensor_parallelism: + new_item = { + ViTForImageClassification: + ModulePolicyDescription(sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)) + ]) + } + policy.update(new_item) + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForImageClassification, + pipeline_forward=ViTForImageClassification_pipeline_forward, + policy=policy) + + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.classifier) + + return held_layers + + +# ViTForMaskedImageModeling +class ViTForMaskedImageModelingPolicy(ViTPolicy): + + def __init__(self) -> None: + super().__init__() + + def module_policy(self): + from transformers.models.vit.modeling_vit import ViTForMaskedImageModeling, ViTModel + + policy = super().module_policy() + + if self.shard_config.pipeline_stage_manager is not None: + self.set_pipeline_forward(model_cls=ViTModel, pipeline_forward=ViTModel_pipeline_forward, policy=policy) + self.set_pipeline_forward(model_cls=ViTForMaskedImageModeling, + pipeline_forward=ViTForMaskedImageModeling_pipeline_forward, + policy=policy) + return policy + + def get_held_layers(self) -> List[nn.Module]: + held_layers = super().get_held_layers() + assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None" + + module = self.model.vit + stage_manager = self.pipeline_stage_manager + if stage_manager.is_last_stage(): + held_layers.append(module.layernorm) + held_layers.append(self.model.decoder) + + return held_layers diff --git a/tests/kit/model_zoo/transformers/__init__.py 
b/tests/kit/model_zoo/transformers/__init__.py
index 4aa01abe1..a298767d1 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -5,3 +5,4 @@ from .gpt import *
 from .llama import *
 from .opt import *
 from .t5 import *
+from .vit import *
diff --git a/tests/kit/model_zoo/transformers/vit.py b/tests/kit/model_zoo/transformers/vit.py
new file mode 100644
index 000000000..93a8d6c61
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/vit.py
@@ -0,0 +1,68 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+# ===============================
+# Register ViT
+# ===============================
+
+config = transformers.ViTConfig(
+    num_hidden_layers=4,
+    # hidden_size=128,
+    # intermediate_size=256,
+    num_attention_heads=4)
+
+
+# define data gen function
+def data_gen():
+    pixel_values = torch.randn(1, 3, 224, 224)
+    return dict(pixel_values=pixel_values)
+
+
+def data_gen_for_image_classification():
+    data = data_gen()
+    data['labels'] = torch.tensor([0])
+    return data
+
+
+def data_gen_for_masked_image_modeling():
+    data = data_gen()
+    num_patches = (config.image_size // config.patch_size)**2
+    bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
+    data['bool_masked_pos'] = bool_masked_pos
+    return data
+
+
+# define output transform function
+output_transform_fn = lambda x: x
+
+# function to get the loss
+loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
+loss_fn_for_image_classification = lambda x: x.logits.mean()
+loss_fn_for_masked_image_modeling = lambda x: x.loss
+
+# register the following models
+# transformers.ViTModel,
+# transformers.ViTForMaskedImageModeling,
+# transformers.ViTForImageClassification,
+model_zoo.register(name='transformers_vit',
+                   model_fn=lambda: transformers.ViTModel(config),
+                   data_gen_fn=data_gen,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_vit_model,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='transformers_vit_for_masked_image_modeling',
+                   model_fn=lambda: transformers.ViTForMaskedImageModeling(config),
+                   data_gen_fn=data_gen_for_masked_image_modeling,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_masked_image_modeling,
+                   model_attribute=ModelAttribute(has_control_flow=True))
+
+model_zoo.register(name='transformers_vit_for_image_classification',
+                   model_fn=lambda: transformers.ViTForImageClassification(config),
+                   data_gen_fn=data_gen_for_image_classification,
+                   output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_image_classification,
+                   model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_shardformer/test_model/test_shard_vit.py b/tests/test_shardformer/test_model/test_shard_vit.py
index af1605b6b..2b02c83e0 100644
--- a/tests/test_shardformer/test_model/test_shard_vit.py
+++ b/tests/test_shardformer/test_model/test_shard_vit.py
@@ -1,9 +1,18 @@
+import os
+
 import pytest
 import torch
 
 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.testing import (
+    assert_hf_output_close,
+    clear_cache_before_run,
+    parameterize,
+    rerun_if_address_is_in_use,
+    spawn,
+)
 from tests.kit.model_zoo import model_zoo
 from tests.test_shardformer.test_model._utils import build_model, run_forward
 
@@ -12,44 +21,58 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
     # check forward
     org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                  output_transform_fn, loss_fn)
-    assert_hf_output_close(org_output, shard_output)
-
+    assert_hf_output_close(org_output, shard_output, atol=1e-3, rtol=1e-3)
     # do backward
     org_loss.backward()
     shard_loss.backward()
 
-    # check grad
-    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
-    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad
-
-    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
-    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
-    all_shard_grad = torch.cat(shard_grad_list, dim=0)
     assert torch.allclose(org_loss, shard_loss,
                           atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
+
+    # unwrap model
+    if org_model.__class__.__name__ == 'ViTModel':
+        vit_model = org_model
+        shard_vit_model = sharded_model
+    else:
+        vit_model = org_model.vit
+        shard_vit_model = sharded_model.vit
+
+    # check attention grad
+    org_grad = vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_grad = shard_vit_model.encoder.layer[0].attention.attention.query.weight.grad
+    shard_weight = shard_vit_model.encoder.layer[0].attention.attention.query.weight
+
+    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
+        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+        torch.distributed.all_gather(shard_grad_list, shard_grad)    # all_gather returns None; it fills shard_grad_list
+        all_shard_grad = torch.cat(shard_grad_list, dim=0)
+    else:
+        all_shard_grad = shard_grad
     assert torch.allclose(org_grad, all_shard_grad,
-                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
+
+
+@parameterize('enable_fused_normalization', [True, False])
+@parameterize('enable_tensor_parallelism', [True, False])
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism):
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+    torch.cuda.empty_cache()
 
 
 def check_vit(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
-    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
-
-    torch.cuda.empty_cache()
+    run_vit_test()
 
 
 @pytest.mark.dist
-@pytest.mark.skip
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_vit():
-    spawn(check_vit, 4)
+    spawn(check_vit, 2)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_shardformer/test_model/test_shard_vit_pipeline.py b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py
new file mode 100644
index 000000000..114992a2a
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_vit_pipeline.py
@@ -0,0 +1,74 @@
+import pytest
+import torch
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.logging import disable_existing_loggers
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_pipeline_model
+
+
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # TODO: add tests for forward/backward later
+    pass
+
+
+@parameterize('enable_tensor_parallelism', [False])
+@parameterize('enable_fused_normalization', [False])
+@parameterize('use_lazy_init', [False])
+# TODO: merge this into test_shard_vit
+def run_vit_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
+    DP_DIM, PP_DIM = 0, 1
+    DP_SIZE, PP_SIZE = 2, 2
+    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
+
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_model_zoo.items():
+
+        inputs = data_gen_fn()
+        inputs = {k: v.cuda() for k, v in inputs.items()}
+        pixel_values = inputs['pixel_values']
+        batch_size = len(pixel_values)
+        hidden_size = 768    # default ViTConfig hidden size
+        hidden_state_shape = (batch_size, 197, hidden_size)    # 197 = 196 patches + 1 [CLS] token
+
+        if not stage_manager.is_first_stage():
+            # change inputs if not the first stage
+            hidden_states = torch.randn(*hidden_state_shape).cuda()
+            # inputs['pixel_values'] = None
+            inputs['hidden_states'] = hidden_states
+
+        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                enable_tensor_parallelism, use_lazy_init)
+        sharded_model.train()
+
+        output = sharded_model(**inputs)
+        if stage_manager.is_last_stage():
+            if name != 'transformers_vit':
+                assert output.loss is not None
+        else:
+            assert output['hidden_states'].shape == hidden_state_shape, \
+                f'hidden_states shape is not correct, output:{output["hidden_states"].shape} is not equal to hidden_state:{hidden_state_shape}'
+
+        torch.cuda.empty_cache()
+
+
+def check_vit(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_vit_test()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_vit():
+    spawn(check_vit, 4)
+
+
+if __name__ == "__main__":
+    test_vit()
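
Editor's note: the masked-image-modeling branch in `ViTForMaskedImageModeling_pipeline_forward` expands the per-patch `bool_masked_pos` mask to pixel resolution before weighting the L1 reconstruction loss. Below is a minimal, standalone sketch of just that mask arithmetic, using the same 224/16 geometry as the test config; the random tensors stand in for real model inputs and decoder outputs.

```python
import torch
import torch.nn as nn

patch_size, image_size, num_channels = 16, 224, 3
size = image_size // patch_size                 # 14 patches per side
num_patches = size ** 2                         # 196

# random per-patch mask, as in data_gen_for_masked_image_modeling
bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()

# expand (1, 196) -> (1, 14, 14) -> (1, 224, 224) -> (1, 1, 224, 224)
mask = bool_masked_pos.reshape(-1, size, size)
mask = mask.repeat_interleave(patch_size, 1).repeat_interleave(patch_size, 2).unsqueeze(1).contiguous()

pixel_values = torch.randn(1, num_channels, image_size, image_size)
reconstructed_pixel_values = torch.randn_like(pixel_values)    # stand-in for self.decoder(...)

# unreduced L1, weighted by the pixel mask and normalized as in pp_forward
reconstruction_loss = nn.functional.l1_loss(pixel_values, reconstructed_pixel_values, reduction="none")
masked_im_loss = (reconstruction_loss * mask).sum() / (mask.sum() + 1e-5) / num_channels
print(mask.shape, masked_im_loss.item())    # torch.Size([1, 1, 224, 224]) <scalar>
```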
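
Both `get_held_layers` and `set_pipeline_forward` lean on `Policy.distribute_layers` and `Policy.get_stage_index`, whose implementations are not part of this diff. The sketch below illustrates the indexing contract the ViT forwards consume, assuming a plain even split with the remainder pushed to later stages; the real helpers may balance layers differently.

```python
from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Hypothetical even split: each stage gets num_layers // num_stages layers,
    with any remainder assigned to the later stages (an assumption, not the real Policy code)."""
    counts = [num_layers // num_stages] * num_stages
    for i in range(num_layers % num_stages):
        counts[num_stages - 1 - i] += 1
    return counts

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    """Convert per-stage counts into the [start_idx, end_idx) range fed to _encoder_forward."""
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

layers_per_stage = distribute_layers(4, 2)    # [2, 2] for the 4-layer test config
for stage in range(2):
    print(stage, get_stage_index(layers_per_stage, stage))
# 0 (0, 2)
# 1 (2, 4)
```

Whatever the real split is, each stage ends up with a half-open `[start_idx, end_idx)` slice of `encoder.layer`, which is exactly what `_encoder_forward` iterates over.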
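
For readers unfamiliar with the method-replacement step: each `*_pipeline_forward` factory closes over `(stage_manager, stage_index)` and returns a `pp_forward` that `set_pipeline_forward` installs as the wrapped module's `forward`. A toy illustration of that binding pattern follows; `ToyViTModel` and `make_pipeline_forward` are hypothetical stand-ins, not shardformer APIs.

```python
import types

class ToyViTModel:
    """Stand-in for ViTModel; only the forward matters here."""

    def forward(self, pixel_values=None, hidden_states=None):
        return "original forward"

def make_pipeline_forward(stage: int, num_stages: int):
    # the factory closes over the stage info, mirroring ViTModel_pipeline_forward
    def pp_forward(self, pixel_values=None, hidden_states=None):
        if stage < num_stages - 1:
            # intermediate stages hand a dict of hidden states to the next stage
            return {"hidden_states": f"output of stage {stage}"}
        return "final output"
    return pp_forward

model = ToyViTModel()
# bind the stage-specialized forward onto the instance, roughly what the
# policy's method replacement does for the wrapped model
model.forward = types.MethodType(make_pipeline_forward(stage=0, num_stages=2), model)
print(model.forward())    # {'hidden_states': 'output of stage 0'}
```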