diff --git a/colossalai/pipeline/policy/__init__.py b/colossalai/pipeline/policy/__init__.py
new file mode 100644
index 000000000..fd9e6e045
--- /dev/null
+++ b/colossalai/pipeline/policy/__init__.py
@@ -0,0 +1,22 @@
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from .base import Policy
+from .bert import BertModel, BertModelPolicy
+
+POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
+    BertModel: BertModelPolicy,
+}
+
+
+def pipeline_parallelize(
+        model: Module,
+        stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
+    if type(model) not in POLICY_MAP:
+        raise NotImplementedError(f"Policy for {type(model)} is not implemented")
+    policy = POLICY_MAP[type(model)](stage_manager)
+    return policy.parallelize_model(model)
diff --git a/colossalai/pipeline/policy/base.py b/colossalai/pipeline/policy/base.py
new file mode 100644
index 000000000..ad595a04b
--- /dev/null
+++ b/colossalai/pipeline/policy/base.py
@@ -0,0 +1,108 @@
+from typing import Any, Dict, List, Optional, Tuple
+
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+from colossalai.lazy import LazyTensor
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+
+class Policy:
+    def __init__(self, stage_manager: PipelineStageManager) -> None:
+        self.stage_manager = stage_manager
+
+    def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]:
+        """Set up the model for pipeline parallelism.
+
+        Args:
+            module (Module): Module to be set up
+
+        Returns:
+            Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Parameters and buffers held by the current stage
+        """
+        hold_params = set()
+        hold_buffers = set()
+
+        def init_layer(layer: Module):
+            for p in layer.parameters():
+                if isinstance(p, LazyTensor):
+                    p.materialize()
+                p.data = p.cuda()
+                hold_params.add(p)
+            for b in layer.buffers():
+                if isinstance(b, LazyTensor):
+                    b.materialize()
+                b.data = b.cuda()
+                hold_buffers.add(b)
+
+        hold_layers = self.get_hold_layers(module)
+
+        for layer in hold_layers:
+            init_layer(layer)
+
+        hold_params_dict = {}
+        hold_buffers_dict = {}
+
+        # release tensors that are not held by this stage
+        for n, p in module.named_parameters():
+            if p in hold_params:
+                hold_params_dict[n] = p
+            else:
+                if isinstance(p, LazyTensor):
+                    p.materialize()
+                    p.data = p.cuda()
+                p.storage().resize_(0)
+        for n, b in module.named_buffers():
+            if b in hold_buffers:
+                hold_buffers_dict[n] = b
+            else:
+                if isinstance(b, LazyTensor):
+                    b.materialize()
+                    b.data = b.cuda()
+                # FIXME(ver217): using a meta tensor may be better here
+                b.storage().resize_(0)
+        return hold_params_dict, hold_buffers_dict
+
+    def replace_forward(self, module: Module) -> None:
+        """Replace the module's forward method in place. This method should be implemented by subclasses. The output of internal layers must be a dict.
+
+        Args:
+            module (Module): Module whose forward method is to be replaced
+        """
+        raise NotImplementedError
+
+    def get_hold_layers(self, module: Module) -> List[Module]:
+        """Get the layers that should be held by the current stage. This method should be implemented by subclasses.
+
+        Args:
+            module (Module): Module to be set up
+
+        Returns:
+            List[Module]: List of layers that should be held by the current stage
+        """
+        raise NotImplementedError
+
+    def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
+        """Get the parameters that should be shared across stages. This method should be implemented by subclasses.
+
+        Args:
+            module (Module): Module to be set up
+
+        Returns:
+            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
+        """
+        raise NotImplementedError
+
+    def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
+        """Parallelize the model for pipeline parallelism.
+
+        Args:
+            module (Module): Module to be set up
+
+        Returns:
+            Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Held parameters, held buffers and shared parameters
+        """
+        hold_params, hold_buffers = self.setup_model(module)
+        self.replace_forward(module)
+        shared_params = self.get_shared_params(module)
+        return hold_params, hold_buffers, shared_params
diff --git a/colossalai/pipeline/policy/bert.py b/colossalai/pipeline/policy/bert.py
new file mode 100644
index 000000000..6f912d2c6
--- /dev/null
+++ b/colossalai/pipeline/policy/bert.py
@@ -0,0 +1,390 @@
+from functools import partial
+from types import MethodType
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, Module
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    BaseModelOutputWithPastAndCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+)
+from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from .base import Policy
+
+logger = logging.get_logger(__name__)
+
+
+def bert_model_forward(
+    self: BertModel,
+    input_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    token_type_ids: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+    encoder_attention_mask: Optional[torch.Tensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    # labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    stage_manager: Optional[PipelineStageManager] = None,
+    hidden_states: Optional[torch.FloatTensor] = None,    # this is the output of the previous stage
+):
+    # TODO: add an explanation of the output here.
+    r"""
+    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
+        the model is configured as a decoder.
+    encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
+        the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
+
+        - 1 for tokens that are **not masked**,
+        - 0 for tokens that are **masked**.
+    past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
+
+        If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
+        don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+        `decoder_input_ids` of shape `(batch_size, sequence_length)`.
+    use_cache (`bool`, *optional*):
+        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+        `past_key_values`).
+    """
+    # preprocess:
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (output_hidden_states
+                            if output_hidden_states is not None else self.config.output_hidden_states)
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if self.config.is_decoder:
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+    else:
+        use_cache = False
+
+    if stage_manager.is_first_stage():
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+        batch_size, seq_length = input_shape
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
+        attention_mask = extended_attention_mask
+    else:
+        input_shape = hidden_states.size()[:-1]
+        batch_size, seq_length = input_shape
+        device = hidden_states.device
+
+    if output_attentions:
+        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+        output_attentions = False
+    if output_hidden_states:
+        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+        output_hidden_states = False
+    if use_cache:
+        logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
+        use_cache = False
+
+    # past_key_values_length
+    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+    if attention_mask is None:
+        attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
+
+    if token_type_ids is None:
+        if hasattr(self.embeddings, "token_type_ids"):
+            buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
+            buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
+            token_type_ids = buffered_token_type_ids_expanded
+        else:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+    # Prepare head mask if needed
+    # 1.0 in head_mask indicate we keep the head
+    # attention_probs has shape bsz x n_heads x N x N
+    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+    hidden_states = hidden_states if hidden_states is not None else None
+    if stage_manager.is_first_stage():
+        hidden_states = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+
+    # If a 2D or 3D attention mask is provided for the cross-attention
+    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
+    if self.config.is_decoder and encoder_hidden_states is not None:
+        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+        if encoder_attention_mask is None:
+            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+    else:
+        encoder_extended_attention_mask = None
+
+    # inherited from bert_layer
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attentions = () if output_attentions else None
+    all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+    if self.encoder.gradient_checkpointing and self.encoder.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+            use_cache = False
+    next_decoder_cache = () if use_cache else None
+
+    # calculate the layer range held by this stage
+    num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
+    start_layer = stage_manager.stage * num_layers_per_stage
+    end_layer = (stage_manager.stage + 1) * num_layers_per_stage
+
+    # layer_outputs
+    layer_outputs = hidden_states if hidden_states is not None else None
+    for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
+        if stage_manager.is_first_stage() and idx == 0:
+            encoder_attention_mask = encoder_extended_attention_mask
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        layer_head_mask = head_mask[idx] if head_mask is not None else None
+        past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+        if self.encoder.gradient_checkpointing and self.encoder.training:
+
+            def create_custom_forward(module):
+
+                def custom_forward(*inputs):
+                    return module(*inputs, past_key_value, output_attentions)
+
+                return custom_forward
+
+            layer_outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(encoder_layer),
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+            )
+        else:
+            layer_outputs = encoder_layer(
+                hidden_states,
+                attention_mask,
+                layer_head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                past_key_value,
+                output_attentions,
+            )
+        hidden_states = layer_outputs[0]
+        if use_cache:
+            next_decoder_cache += (layer_outputs[-1],)
+        if output_attentions:
+            all_self_attentions = all_self_attentions + (layer_outputs[1],)
+            if self.config.add_cross_attention:
+                all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+    if output_hidden_states:
+        all_hidden_states = all_hidden_states + (hidden_states,)
+
+    # end of the stage loop
+    sequence_output = layer_outputs[0] if layer_outputs is not None else None
+
+    if stage_manager.is_last_stage():
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+        if not return_dict:
+            return (sequence_output, pooled_output) + layer_outputs[1:]
+
+    # output of non-first and non-last stages:
+    if not return_dict:
+        return tuple(v for v in [
+            hidden_states,
+            next_decoder_cache,
+            all_hidden_states,
+            all_self_attentions,
+            all_cross_attentions,
+        ] if v is not None)
+
+    # return_dict is not supported at this moment
+    return BaseModelOutputWithPastAndCrossAttentions(
+        last_hidden_state=hidden_states,
+        past_key_values=next_decoder_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+        cross_attentions=all_cross_attentions,
+    )
+
+
+# The layer partition policy for BertModel
+class BertModelPolicy(Policy):
+
+    def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
+        super().__init__(stage_manager)
+        self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
+
+    def get_hold_layers(self, module: BertModel) -> List[Module]:
+        """
+        Get the pipeline layers held by the current stage.
+        """
+        hold_layers = []
+        if self.stage_manager.is_first_stage():
+            hold_layers.append(module.embeddings)
+        num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
+        start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager.stage > 0 else 0
+        end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage]
+        hold_layers.extend(module.encoder.layer[start_idx:end_idx])
+
+        if self.stage_manager.is_last_stage():
+            hold_layers.append(module.pooler)
+
+        return hold_layers
+
+    def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
+        """No shared params in BertModel."""
+        return []
+
+    def replace_forward(self, module: Module) -> None:
+        module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module)
+
+    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
+        """
+        Divide layers into stages.
+        """
+        quotient = num_layers // num_stages
+        remainder = num_layers % num_stages
+
+        # calculate the num_layers per stage
+        layers_per_stage = [quotient] * num_stages
+
+        # deal with the remaining layers: place them on the middle stages
+        if remainder > 0:
+            start_position = num_stages // 2 - remainder // 2
+            for i in range(start_position, start_position + remainder):
+                layers_per_stage[i] += 1
+        return layers_per_stage
+
+
+def bert_for_pretraining_forward(
+    self: BertForPreTraining,
+    input_ids: Optional[torch.Tensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    token_type_ids: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+    inputs_embeds: Optional[torch.Tensor] = None,
+    labels: Optional[torch.Tensor] = None,
+    next_sentence_label: Optional[torch.Tensor] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    hidden_states: Optional[torch.FloatTensor] = None,
+    stage_manager: Optional[PipelineStageManager] = None,
+) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
+
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    outputs = self.bert(
+        input_ids,
+        attention_mask=attention_mask,
+        token_type_ids=token_type_ids,
+        position_ids=position_ids,
+        head_mask=head_mask,
+        inputs_embeds=inputs_embeds,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    sequence_output, pooled_output = outputs[:2]
+    prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
+
+    total_loss = None
+    if labels is not None and next_sentence_label is not None:
+        loss_fct = CrossEntropyLoss()
+        masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+        next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
+        total_loss = masked_lm_loss + next_sentence_loss
+
+    if not return_dict:
+        output = (prediction_scores, seq_relationship_score) + outputs[2:]
+        return ((total_loss,) + output) if total_loss is not None else output
+
+    return BertForPreTrainingOutput(
+        loss=total_loss,
+        prediction_logits=prediction_scores,
+        seq_relationship_logits=seq_relationship_score,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
+
+
+class BertForPreTrainingPolicy(Policy):
+
+    def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
+        super().__init__(stage_manager)
+        self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
+
+    def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
+        """
+        Get the pipeline layers held by the current stage.
+        """
+        hold_layers = []
+        if self.stage_manager.is_first_stage():
+            hold_layers.append(module.bert.embeddings)
+        num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
+        start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager.stage > 0 else 0
+        end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage]
+        hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
+        if self.stage_manager.is_last_stage():
+            hold_layers.append(module.cls)
+
+        return hold_layers
+
+    def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]:
+        """No shared params in BertForPreTraining."""
+        return []
+
+    def replace_forward(self, module: Module) -> None:
+        module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), module)
+
+    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
+        """
+        Divide layers into stages.
+        """
+        quotient = num_layers // num_stages
+        remainder = num_layers % num_stages
+
+        # calculate the num_layers per stage
+        layers_per_stage = [quotient] * num_stages
+
+        # deal with the remaining layers: place them on the middle stages
+        if remainder > 0:
+            start_position = num_stages // 2 - remainder // 2
+            for i in range(start_position, start_position + remainder):
+                layers_per_stage[i] += 1
+        return layers_per_stage
diff --git a/colossalai/pipeline/policy/bloom.py b/colossalai/pipeline/policy/bloom.py
new file mode 100644
index 000000000..8dffcd8f9
--- /dev/null
+++ b/colossalai/pipeline/policy/bloom.py
@@ -0,0 +1,153 @@
+import warnings
+from functools import partial
+from types import MethodType
+from typing import Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, Module
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
+from transformers.models.bloom.modeling_bloom import BloomModel
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+
+from .base import Policy
+
+logger = logging.get_logger(__name__)
+
+
+def bloom_model_forward(
+    self: BloomModel,
+    input_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    **deprecated_arguments,
+) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+    if deprecated_arguments.pop("position_ids", False) is not False:
+        # `position_ids` could have been `torch.Tensor` or `None`, so defaulting pop to `False` lets us detect
+        # whether users explicitly passed `None`
+        warnings.warn(
+            "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
+            " passing `position_ids`.",
+            FutureWarning,
+        )
+    if len(deprecated_arguments) > 0:
+        raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (output_hidden_states
+                            if output_hidden_states is not None else self.config.output_hidden_states)
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if input_ids is not None and inputs_embeds is not None:
+        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+    elif input_ids is not None:
+        batch_size, seq_length = input_ids.shape
+    elif inputs_embeds is not None:
+        batch_size, seq_length, _ = inputs_embeds.shape
+    else:
+        raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+    if past_key_values is None:
+        past_key_values = tuple([None] * len(self.h))
+
+    # Prepare head mask if needed
+    # 1.0 in head_mask indicate we keep the head
+    # attention_probs has shape batch_size x num_heads x N x N
+    # head_mask has shape n_layer x batch x num_heads x N x N
+    head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+    if inputs_embeds is None:
+        inputs_embeds = self.word_embeddings(input_ids)
+
+    hidden_states = self.word_embeddings_layernorm(inputs_embeds)
+
+    presents = () if use_cache else None
+    all_self_attentions = () if output_attentions else None
+    all_hidden_states = () if output_hidden_states else None
+
+    if self.gradient_checkpointing and self.training:
+        if use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+            use_cache = False
+
+    # Compute alibi tensor: check build_alibi_tensor documentation
+    seq_length_with_past = seq_length
+    past_key_values_length = 0
+    if past_key_values[0] is not None:
+        past_key_values_length = past_key_values[0][0].shape[2]
+        seq_length_with_past = seq_length_with_past + past_key_values_length
+    if attention_mask is None:
+        attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+    else:
+        attention_mask = attention_mask.to(hidden_states.device)
+
+    alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
+
+    causal_mask = self._prepare_attn_mask(
+        attention_mask,
+        input_shape=(batch_size, seq_length),
+        past_key_values_length=past_key_values_length,
+    )
+
+    for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+
+            def create_custom_forward(module):
+
+                def custom_forward(*inputs):
+                    # None for past_key_value
+                    return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                return custom_forward
+
+            outputs = torch.utils.checkpoint.checkpoint(
+                create_custom_forward(block),
+                hidden_states,
+                alibi,
+                causal_mask,
+                layer_past,
+                head_mask[i],
+            )
+        else:
+            outputs = block(
+                hidden_states,
+                layer_past=layer_past,
+                attention_mask=causal_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                alibi=alibi,
+            )
+
+        hidden_states = outputs[0]
+        if use_cache is True:
+            presents = presents + (outputs[1],)
+
+        if output_attentions:
+            all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+    # Add last hidden state
+    hidden_states = self.ln_f(hidden_states)
+
+    if output_hidden_states:
+        all_hidden_states = all_hidden_states + (hidden_states,)
+
+    if not return_dict:
+        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+    return BaseModelOutputWithPastAndCrossAttentions(
+        last_hidden_state=hidden_states,
+        past_key_values=presents,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attentions,
+    )
diff --git a/tests/test_pipeline/test_policy/test_bert_model.py b/tests/test_pipeline/test_policy/test_bert_model.py
new file mode 100644
index 000000000..c92f7f6c3
--- /dev/null
+++ b/tests/test_pipeline/test_policy/test_bert_model.py
@@ -0,0 +1,112 @@
+import pytest
+import torch
+import torch.distributed as dist
+from transformers.models.bert.modeling_bert import BertModel
+
+import colossalai
+from colossalai.cluster import ProcessGroupMesh
+from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def check_bert_model_forward():
+    model = BertModel.from_pretrained('bert-base-uncased')
+    DP_DIM, PP_DIM = 0, 1
+    DP_SIZE, PP_SIZE = 2, 2
+    RANK_TO_COORDINATE = {
+        0: (0, 0),
+        1: (0, 1),
+        2: (1, 0),
+        3: (1, 1),
+    }
+    PP_RANKS_IN_GROUP = {
+        0: [0, 1],
+        1: [0, 1],
+        2: [2, 3],
+        3: [2, 3],
+    }
+    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    # print(pg_mesh)
+
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
+    rank = dist.get_rank()
+    # print(rank)
+
+    x = torch.randint(0, 1000, (2, 3))
+    hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
+    if stage_manager.stage == 0:
+        attention_mask = torch.ones_like(x)
+        output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
+        print(output[0].shape)
+        assert output[0].shape == (2, 3, 768)
+        print('start the training')
+    else:
+        attention_mask = torch.ones((2, 12, 3, 3))
+        output = bert_model_forward(self=model,
+                                    hidden_states=hidden_states,
+                                    attention_mask=attention_mask,
+                                    stage_manager=stage_manager)
+        print(output[0].shape)
+        assert output[0].shape == (2, 3, 768)
+        print('end the training')
+        print(output)
+
+    # assert output[1].shape == (2, 768)
+
+
+def check_bert_model_policy():
+    model = BertModel.from_pretrained('bert-base-uncased')
+    DP_DIM, PP_DIM = 0, 1
+    DP_SIZE, PP_SIZE = 2, 2
+    RANK_TO_COORDINATE = {
+        0: (0, 0),
+        1: (0, 1),
+        2: (1, 0),
+        3: (1, 1),
+    }
+    PP_RANKS_IN_GROUP = {
+        0: [0, 1],
+        1: [0, 1],
+        2: [2, 3],
+        3: [2, 3],
+    }
+    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
+    # print(pg_mesh)
+
+    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
+    rank = dist.get_rank()
+
+    model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2)
+    assert model_policy.layers_per_stage == [6, 6]
+    layers = model_policy.get_hold_layers(model)
+    for layer in layers:
+        print(layer)
+
+
+def run_dist_model(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
+    check_bert_model_forward()
+
+
+def run_dist_policy(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
+    check_bert_model_policy()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_bert_model_forward():
+    spawn(run_dist_model, 4)
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_bert_model_policy():
+    spawn(run_dist_policy, 4)
+
+
+if __name__ == "__main__":
+    """test the bert model forward and bert model policy"""
+    test_bert_model_forward()
+    test_bert_model_policy()
diff --git a/tests/test_pipeline/test_stage_manager.py b/tests/test_pipeline/test_stage_manager.py
index be4591d58..67a2e9053 100644
--- a/tests/test_pipeline/test_stage_manager.py
+++ b/tests/test_pipeline/test_stage_manager.py
@@ -21,7 +21,7 @@ def check_stage_manager():
         1: [0, 1],
         2: [2, 3],
         3: [2, 3],
-    }
+    }
     pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
     stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
     rank = dist.get_rank()
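Reviewer note: the snippet below is a minimal, standalone sketch of how the layer partitioning in this PR behaves. `distribute_layers` is re-implemented here purely for illustration (it mirrors the policy helper, assuming the remainder is centered over `num_stages` as in the code above), and the layer/stage counts are arbitrary examples rather than values taken from the PR.

```python
# Illustration only: mirrors BertModelPolicy.distribute_layers and the
# cumulative-sum indexing used in get_hold_layers.
from typing import List

import numpy as np


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # even split, with any remainder placed on the middle stages
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    start_position = num_stages // 2 - remainder // 2
    for i in range(start_position, start_position + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage


print(distribute_layers(12, 2))    # [6, 6] -> matches the bert-base assertion in the test
print(distribute_layers(12, 5))    # [2, 3, 3, 2, 2]

# stage -> encoder.layer slice, as computed in get_hold_layers
layers_per_stage = distribute_layers(12, 4)
accumulated = np.cumsum(layers_per_stage)    # [3, 6, 9, 12]
stage = 2
start_idx = accumulated[stage - 1] if stage > 0 else 0
end_idx = accumulated[stage]
print(start_idx, end_idx)    # 6 9 -> stage 2 holds encoder.layer[6:9]
```

One plausible reason to center the remainder is that the first and last stages also hold the embeddings and the pooler/cls head, so keeping them slightly lighter on encoder layers can help balance the pipeline.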