[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
Hongxin Liu
2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions
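Note on the change: the shardformer layers now write their shards back into the existing parameters instead of constructing fresh ones, which is why the per-policy `postprocess` weight re-binding is deleted in the hunks below. A minimal, illustrative sketch of that idea follows; `shard_embedding_inplace` is a made-up helper, not the shardformer API, and it only shows why keeping the `nn.Parameter` object intact preserves weight tying.

# Illustrative sketch only -- `shard_embedding_inplace` is hypothetical, not the
# shardformer API.  "In-place" here means the shard is written back into the
# existing nn.Parameter via `.data`, so the Parameter's identity is preserved
# and modules that tie to it (word embeddings and the LM decoder in BERT)
# stay tied after sharding.
import torch.nn as nn

def shard_embedding_inplace(emb: nn.Embedding, rank: int, world_size: int) -> None:
    vocab_size = emb.num_embeddings
    assert vocab_size % world_size == 0, "vocab size must be divisible by the TP size"
    shard = vocab_size // world_size
    # Keep the same Parameter object; only swap the tensor it holds.
    emb.weight.data = emb.weight.data[rank * shard:(rank + 1) * shard].clone()

embeddings = nn.Embedding(8, 4)
decoder = nn.Linear(4, 8, bias=False)
decoder.weight = embeddings.weight            # tie weights, as BERT does
shard_embedding_inplace(embeddings, rank=0, world_size=2)
assert decoder.weight is embeddings.weight    # tying survives the shard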


@@ -1,13 +1,11 @@
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MultipleChoiceModelOutput,
@@ -28,12 +26,11 @@ from transformers.models.bert.modeling_bert import (
BertLMHeadModel,
BertModel,
)
from transformers.utils import ModelOutput, logging
from transformers.utils import logging
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
logger = logging.get_logger(__name__)
@@ -177,6 +174,17 @@ class BertPolicy(Policy):
target_key=BertLMPredictionHead)
return base_policy
def add_lm_prediction_policy(self, base_policy):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
method_replacement = {
'_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
'_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
}
self.append_or_create_method_replacement(description=method_replacement,
policy=base_policy,
target_key=BertLMPredictionHead)
return base_policy
def postprocess(self):
return self.model
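The new `add_lm_prediction_policy` leaves `BertLMPredictionHead` itself in place but swaps its `_save_to_state_dict` / `_load_from_state_dict` hooks for the `ParallelModule` versions, presumably so the tied decoder weight, now a sharded parameter, is saved and loaded consistently. A hedged sketch of how a method-replacement description of this shape could be applied at shard time is below; `apply_method_replacement` is a made-up name, and the real logic lives in shardformer's sharder.

# Hedged sketch (illustrative; `apply_method_replacement` is not a real
# shardformer function): a method-replacement description maps method names to
# plain functions, and applying it rebinds each function as a bound method on
# every matching module instance via types.MethodType.
from types import MethodType

import torch.nn as nn

def apply_method_replacement(model: nn.Module, target_cls: type, replacement: dict) -> None:
    for module in model.modules():
        if isinstance(module, target_cls):
            for name, fn in replacement.items():
                # e.g. name == '_load_from_state_dict', fn == ParallelModule._load_from_state_dict
                setattr(module, name, MethodType(fn, module))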
@@ -240,6 +248,7 @@ class BertForPreTrainingPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForPreTraining
self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
return policy
@@ -266,21 +275,13 @@ class BertForPreTrainingPolicy(BertPolicy):
model = self.model
if self.pipeline_stage_manager:
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: model.bert.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
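`get_shared_params` now reports the tied pair as a mapping from pipeline-stage index to the parameter living on that stage (first stage: embedding weight, last stage: decoder weight). A hedged sketch of how a pipeline runner might consume such an entry, all-reducing the gradients of the tied copies so they stay identical, is below; `sync_shared_param_grads` is an assumption for illustration, not the actual ColossalAI plugin code.

# Hedged sketch, not ColossalAI's actual code: each {stage_index: parameter}
# entry describes one weight duplicated across pipeline stages; the ranks that
# own those stages all-reduce its gradient so the copies never drift apart.
import torch.distributed as dist

def sync_shared_param_grads(shared_params, my_stage: int, group=None) -> None:
    for mapping in shared_params:     # e.g. [{0: emb.weight, num_stages - 1: decoder.weight}]
        param = mapping.get(my_stage)
        if param is not None and param.grad is not None:
            # `group` is assumed to contain exactly the ranks listed in `mapping`.
            dist.all_reduce(param.grad, group=group)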
@@ -291,6 +292,7 @@ class BertLMHeadModelPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertLMHeadModel
self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
return policy
@@ -316,21 +318,13 @@ class BertLMHeadModelPolicy(BertPolicy):
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
@@ -341,6 +335,7 @@ class BertForMaskedLMPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForMaskedLM
self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
return policy
@@ -366,21 +361,13 @@ class BertForMaskedLMPolicy(BertPolicy):
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
@@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward(
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
# -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward(
stage_index: Optional[List[int]] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair