Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-02 17:46:42 +00:00)
[shardformer] Support the T5ForTokenClassification model (#5816)
* t5 token, still pytest fail
* Resolve T5 Pytest Failure
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
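For context, a minimal sketch of what this change enables: tensor-parallel sharding of T5ForTokenClassification through ShardFormer. This snippet is illustrative only (not part of the commit) and assumes the ShardConfig/ShardFormer API of this ColossalAI release plus a transformers version that ships T5ForTokenClassification; the toy config values are arbitrary.

import torch.distributed as dist
from transformers import T5Config, T5ForTokenClassification

from colossalai.shardformer import ShardConfig, ShardFormer

# Toy model; num_labels is arbitrary for illustration.
model = T5ForTokenClassification(T5Config(num_labels=9))

# Shard over the default process group (assumes torch.distributed is already
# initialized); the T5ForTokenClassificationPolicy added by this PR is
# resolved automatically from the model class.
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,
    enable_tensor_parallelism=True,
)
sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)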
colossalai/shardformer/modeling/t5.py

@@ -8,8 +8,15 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
+    TokenClassifierOutput,
 )
-from transformers.models.t5.modeling_t5 import T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Stack
+from transformers.models.t5.modeling_t5 import (
+    T5EncoderModel,
+    T5ForConditionalGeneration,
+    T5ForTokenClassification,
+    T5Model,
+    T5Stack,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -582,6 +589,71 @@ class T5PipelineForwards:
 
         return outputs
 
+    @staticmethod
+    def t5_for_token_classification_forward(
+        self: T5ForTokenClassification,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        stage_manager: Optional[PipelineStageManager] = None,
+        hidden_states: Optional[torch.FloatTensor] = None,
+        position_bias: Optional[torch.Tensor] = None,
+        encoder_decoder_position_bias: Optional[torch.Tensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        backward_tensor_keys: Optional[List[str]] = None,
+        stage_index: Optional[List[int]] = None,
+        decoder_starting_stage: Optional[int] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+        r"""
+        This function is modified on the basis of transformers.models.t5.modeling_t5.T5ForTokenClassification.forward.
+        Please refer to the original code of transformers for more details.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = T5PipelineForwards.t5_stack_forward(
+            self.transformer.encoder,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            position_bias=position_bias,
+            encoder_decoder_position_bias=encoder_decoder_position_bias,
+            stage_index=stage_index,
+            decoder_starting_stage=decoder_starting_stage,
+        )
+        if stage_manager.is_last_stage():
+            sequence_output = outputs[0]
+
+            sequence_output = self.dropout(sequence_output)
+            logits = self.classifier(sequence_output)
+
+            loss = None
+            if labels is not None:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+            if not return_dict:
+                output = (logits,) + outputs[2:]
+                return ((loss,) + output) if loss is not None else output
+
+            return TokenClassifierOutput(
+                loss=loss,
+                logits=logits,
+                hidden_states=outputs.hidden_states,
+                attentions=outputs.attentions,
+            )
+
+        return outputs
+
 
 def get_t5_flash_attention_forward():
     from transformers.models.t5.modeling_t5 import T5Attention
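A note on the self.transformer.encoder indirection in the new forward: in recent transformers releases, T5ForTokenClassification is encoder-only and stores a T5EncoderModel under the transformer attribute, with dropout and classifier heads on top. A quick illustrative check (not part of the commit, assumes a transformers version that ships this class):

from transformers import T5Config, T5ForTokenClassification

model = T5ForTokenClassification(T5Config(num_labels=9))
# The backbone is a T5EncoderModel held as `model.transformer`, which is why
# the pipeline forward above runs t5_stack_forward on self.transformer.encoder
# and applies dropout/classifier only on the last stage.
print(type(model.transformer).__name__)  # T5EncoderModel
print(model.classifier)  # Linear(in_features=d_model, out_features=num_labels)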
colossalai/shardformer/policies/auto_policy.py

@@ -68,6 +68,9 @@ _POLICY_LIST = {
         file_name="t5", class_name="T5ForConditionalGenerationPolicy"
     ),
     "transformers.models.t5.modeling_t5.T5EncoderModel": PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
+    "transformers.models.t5.modeling_t5.T5ForTokenClassification": PolicyLocation(
+        file_name="t5", class_name="T5ForTokenClassificationPolicy"
+    ),
     # GPT2
     "transformers.models.gpt2.modeling_gpt2.GPT2Model": PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
     "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel": PolicyLocation(
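With this entry in _POLICY_LIST, automatic policy resolution should pick up the new class by the model's fully qualified name. An illustrative sketch, assuming get_autopolicy from the same module performs that lookup:

from transformers import T5Config, T5ForTokenClassification

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# Lookup keys on "transformers.models.t5.modeling_t5.T5ForTokenClassification".
model = T5ForTokenClassification(T5Config(num_labels=9))
policy = get_autopolicy(model)
print(type(policy).__name__)  # expected: T5ForTokenClassificationPolicy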
colossalai/shardformer/policies/t5.py

@@ -31,7 +31,13 @@ from ..modeling.t5 import (
 )
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ["distribute_t5_layers", "T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
+__all__ = [
+    "distribute_t5_layers",
+    "T5ModelPolicy",
+    "T5ForConditionalGenerationPolicy",
+    "T5EncoderPolicy",
+    "T5ForTokenClassificationPolicy",
+]
 
 
 class T5BasePolicy(Policy):
@@ -312,9 +318,13 @@ class T5BasePolicy(Policy):
         assert self.pipeline_stage_manager is not None
         stage_manager = self.pipeline_stage_manager
 
-        model = self.model
-        encoder = self.model.encoder
-        decoder = getattr(self.model, "decoder", None)
+        if self.model.__class__.__name__ == "T5ForTokenClassification":
+            model = self.model.transformer
+        else:
+            model = self.model
+
+        encoder = model.encoder
+        decoder = getattr(model, "decoder", None)
 
         num_encoder_layers = len(encoder.block)
         num_decoder_layers = len(decoder.block) if decoder else 0
@@ -353,7 +363,11 @@
             raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
         stage_manager = self.pipeline_stage_manager
 
-        encoder = self.model.encoder
+        if self.model.__class__.__name__ == "T5ForTokenClassification":
+            encoder = self.model.transformer.encoder
+        else:
+            encoder = self.model.encoder
+
         decoder = getattr(self.model, "decoder", None)
 
         num_encoder_layers = len(encoder.block)
@@ -542,3 +556,46 @@ class T5EncoderPolicy(T5BasePolicy):
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
         return []
+
+
+class T5ForTokenClassificationPolicy(T5EncoderPolicy):
+    def module_policy(self):
+        from transformers.models.t5.modeling_t5 import T5ForTokenClassification
+
+        policy = super().module_policy()
+
+        if self.shard_config.enable_tensor_parallelism:
+            addon_module = {
+                T5ForTokenClassification: ModulePolicyDescription(
+                    sub_module_replacement=[
+                        SubModuleReplacementDescription(
+                            suffix="dropout",
+                            target_module=DropoutForParallelInput,
+                        )
+                    ]
+                )
+            }
+            policy.update(addon_module)
+        if self.pipeline_stage_manager:
+            self.set_pipeline_forward(
+                model_cls=T5ForTokenClassification,
+                new_forward=T5PipelineForwards.t5_for_token_classification_forward,
+                policy=policy,
+            )
+
+        return policy
+
+    def get_held_layers(self) -> List[nn.Module]:
+        """
+        Get pipeline layers for the current stage.
+        """
+        held_layers = super().get_held_layers()
+        stage_manager = self.pipeline_stage_manager
+        if stage_manager.is_last_stage(ignore_chunk=True):
+            held_layers.append(self.model.dropout)
+            held_layers.append(self.model.classifier)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        # no shared params for the token classification model
+        return []
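For the pipeline branch, a hedged end-to-end sketch: HybridParallelPlugin supplies the PipelineStageManager that set_pipeline_forward and get_held_layers rely on. Launch details and plugin arguments vary across ColossalAI versions, so treat this as an assumption-laden illustration, not code from the commit:

import colossalai
from transformers import T5Config, T5ForTokenClassification

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # exact launch signature varies by version

# pp_size=2 splits the encoder stack across two stages; the policy above swaps
# in T5PipelineForwards.t5_for_token_classification_forward on each stage and
# keeps dropout/classifier only on the last one.
plugin = HybridParallelPlugin(tp_size=1, pp_size=2, microbatch_size=1)
booster = Booster(plugin=plugin)

model = T5ForTokenClassification(T5Config(num_labels=9))
model, *_ = booster.boost(model)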