Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-27 04:33:04 +00:00)
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version
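The central idea, "inplace sharding," is that the shardformer policies below replace a module's parameters with their local shards while keeping the module object (and hence references to it, such as tied weights) alive, instead of rebuilding the module and copying weights over. The following is a minimal, self-contained sketch of that idea in plain PyTorch; the helper name `shard_linear_colwise_` and the explicit rank/world-size arguments are illustrative assumptions, not ColossalAI's API.

```python
import torch.nn as nn


def shard_linear_colwise_(linear: nn.Linear, rank: int, world_size: int) -> nn.Linear:
    """Hypothetical helper: keep `linear` itself, but swap its parameters for the local shard."""
    out_features, _ = linear.weight.shape
    assert out_features % world_size == 0, "output dim must divide evenly across ranks"
    shard = out_features // world_size
    rows = slice(rank * shard, (rank + 1) * shard)
    # Re-register the parameters in place; the module object (and any reference to it) is unchanged.
    linear.weight = nn.Parameter(linear.weight.data[rows].clone())
    if linear.bias is not None:
        linear.bias = nn.Parameter(linear.bias.data[rows].clone())
    linear.out_features = shard
    return linear


fc = nn.Linear(8, 16)
same_obj = fc
shard_linear_colwise_(fc, rank=0, world_size=4)
assert same_obj is fc and fc.weight.shape == (4, 8)   # identity preserved, weight sharded
```

Because the parameter swap happens in place, code that shares a parameter (for example a tied embedding and LM head) keeps pointing at the same object, which is what lets the policies below drop their manual re-binding steps.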
@@ -1,13 +1,11 @@
 from functools import partial
 from types import MethodType
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional

 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import CrossEntropyLoss, Module
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     BaseModelOutputWithPoolingAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
     MultipleChoiceModelOutput,
@@ -28,12 +26,11 @@ from transformers.models.bert.modeling_bert import (
     BertLMHeadModel,
     BertModel,
 )
-from transformers.utils import ModelOutput, logging
+from transformers.utils import logging

 import colossalai.shardformer.layer as col_nn
 from colossalai.pipeline.stage_manager import PipelineStageManager

 from .._utils import getattr_, setattr_
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

 logger = logging.get_logger(__name__)
@@ -177,6 +174,17 @@ class BertPolicy(Policy):
                                                             target_key=BertLMPredictionHead)
         return base_policy

+    def add_lm_prediction_policy(self, base_policy):
+        from transformers.models.bert.modeling_bert import BertLMPredictionHead
+        method_replacement = {
+            '_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
+            '_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
+        }
+        self.append_or_create_method_replacement(description=method_replacement,
+                                                 policy=base_policy,
+                                                 target_key=BertLMPredictionHead)
+        return base_policy
+
     def postprocess(self):
         return self.model

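The new `add_lm_prediction_policy` above swaps `BertLMPredictionHead`'s `_save_to_state_dict` / `_load_from_state_dict` for the `ParallelModule` versions, so checkpoints keep the original, unsharded layout even though the decoder weight is now sharded in place. Below is a hedged sketch of that general pattern (patching the state-dict hook on a single module instance); `gather_from_ranks` is a stand-in placeholder, not the collective ColossalAI actually uses.

```python
from types import MethodType

import torch
import torch.nn as nn


def gather_from_ranks(shard: torch.Tensor) -> torch.Tensor:
    # Placeholder: a real tensor-parallel setup would all-gather the shards here.
    return shard


def _save_full_to_state_dict(self, destination, prefix, keep_vars):
    # Write the logically full weight instead of the local shard.
    destination[prefix + "weight"] = gather_from_ranks(self.weight.detach())
    if self.bias is not None:
        destination[prefix + "bias"] = gather_from_ranks(self.bias.detach())


linear = nn.Linear(4, 4)
# Patch just this instance, mirroring a per-module method replacement.
linear._save_to_state_dict = MethodType(_save_full_to_state_dict, linear)
print(linear.state_dict().keys())   # still looks like an ordinary nn.Linear checkpoint
```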
@@ -240,6 +248,7 @@ class BertForPreTrainingPolicy(BertPolicy):
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
+        policy = self.add_lm_prediction_policy(policy)
         from transformers.models.bert.modeling_bert import BertForPreTraining
         self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
         return policy
@@ -266,21 +275,13 @@
         model = self.model
         if self.pipeline_stage_manager:
             if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
-                #tie weights
+                # tie weights
                 return [{
                     0: model.bert.embeddings.word_embeddings.weight,
                     self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
                 }]
         return []
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-            for k, v in binding_map.items():
-                param = getattr_(self.model, k)
-                setattr_(self.model, v, param)
-        return self.model


 # BertLMHeadModel
 class BertLMHeadModelPolicy(BertPolicy):
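The `get_shared_params` bodies in these hunks detect weight tying by comparing object identity: the word embedding and the LM-head decoder weight are the same `nn.Parameter` exactly when they are tied, and the returned `{stage: parameter}` mapping tells the pipeline runtime which stages (first and last) must keep that parameter synchronised. A small self-contained illustration of the id-based check, with made-up module names:

```python
import torch.nn as nn


class TinyLM(nn.Module):

    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab, bias=False)
        self.head.weight = self.embed.weight    # tie input and output embeddings


model = TinyLM()
num_stages = 4
shared_params = []
if id(model.embed.weight) == id(model.head.weight):
    # stage 0 holds the embedding, the last stage holds the LM head
    shared_params.append({0: model.embed.weight, num_stages - 1: model.head.weight})
print(shared_params)
```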
@@ -291,6 +292,7 @@ class BertLMHeadModelPolicy(BertPolicy):
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
+        policy = self.add_lm_prediction_policy(policy)
         from transformers.models.bert.modeling_bert import BertLMHeadModel
         self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
         return policy
@@ -316,21 +318,13 @@
         bert_model = self.model.bert
         if self.pipeline_stage_manager:
             if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
-                #tie weights
+                # tie weights
                 return [{
                     0: bert_model.embeddings.word_embeddings.weight,
                     self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
                 }]
         return []
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-            for k, v in binding_map.items():
-                param = getattr_(self.model, k)
-                setattr_(self.model, v, param)
-        return self.model


 # BertForMaskedLM
 class BertForMaskedLMPolicy(BertPolicy):
@@ -341,6 +335,7 @@ class BertForMaskedLMPolicy(BertPolicy):
     def module_policy(self):
         policy = super().module_policy()
         policy = self.add_lm_head_policy(policy)
+        mpolicy = self.add_lm_prediction_policy(policy)
         from transformers.models.bert.modeling_bert import BertForMaskedLM
         self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
         return policy
@@ -366,21 +361,13 @@
         bert_model = self.model.bert
         if self.pipeline_stage_manager:
             if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
-                #tie weights
+                # tie weights
                 return [{
                     0: bert_model.embeddings.word_embeddings.weight,
                     self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
                 }]
         return []
-
-    def postprocess(self):
-        if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
-            binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
-            for k, v in binding_map.items():
-                param = getattr_(self.model, k)
-                setattr_(self.model, v, param)
-        return self.model


 # BertForSequenceClassification
 class BertForSequenceClassificationPolicy(BertPolicy):
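The `postprocess` bodies that these hunks appear to drop re-tied `cls.predictions.decoder.weight` to the word embedding after tensor-parallel sharding, using dotted-path helpers from `.._utils`; with inplace sharding the original parameter object survives the replacement, so that re-binding step seems to become unnecessary. For reference, a hedged sketch of what such dotted-path `getattr_`/`setattr_` helpers do (these are stand-ins, not the actual `colossalai.shardformer._utils` implementations):

```python
import functools


def getattr_(obj, path):
    # Resolve a dotted attribute path such as "bert.embeddings.word_embeddings.weight".
    return functools.reduce(getattr, path.split("."), obj)


def setattr_(obj, path, value):
    parent_path, _, leaf = path.rpartition(".")
    parent = getattr_(obj, parent_path) if parent_path else obj
    setattr(parent, leaf, value)


binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
# for src, dst in binding_map.items():          # `model` would be a BertForPreTraining instance
#     setattr_(model, dst, getattr_(model, src))
```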
@@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward(
     stage_manager: Optional[PipelineStageManager] = None,
     stage_index: Optional[List[int]] = None,
 ):
     # -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
         Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
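The `-100` values mentioned in the docstring follow the standard ignore-index convention of `CrossEntropyLoss`: positions labelled `-100` contribute nothing to the masked-LM loss, so only the masked tokens are scored. A quick standalone example:

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(2, 5, vocab_size)      # (batch, seq_len, vocab)
labels = torch.full((2, 5), -100)           # ignore every position ...
labels[0, 2] = 7                            # ... except the masked tokens
labels[1, 4] = 3
loss = CrossEntropyLoss()(logits.view(-1, vocab_size), labels.view(-1))
print(loss)
```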
@@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward(
     stage_index: Optional[List[int]] = None,
     **kwargs,
 ):
-    #-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
+    # -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
     r"""
     labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
         Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
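Both replacement forwards above take a `stage_manager` and a `stage_index`, which `set_pipeline_forward` wires in so each pipeline stage executes only its own slice of the encoder and only the last stage runs the task head. A minimal sketch of that control flow, with illustrative names rather than the actual ColossalAI signatures:

```python
from typing import List, Optional

import torch
import torch.nn as nn


def staged_forward(layers: nn.ModuleList,
                   hidden: torch.Tensor,
                   stage_index: List[int],
                   is_last_stage: bool,
                   head: Optional[nn.Module] = None):
    start, end = stage_index
    for layer in layers[start:end]:      # only the layers owned by this stage
        hidden = layer(hidden)
    if is_last_stage and head is not None:
        return head(hidden)              # final stage produces the task output
    return hidden                        # intermediate stages hand activations to the next stage
```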